diff options
author | Will Hunt <will@half-shot.uk> | 2018-07-09 13:31:21 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-07-09 13:31:21 +0000 |
commit | a7f4ebbd3e786ce2e156127a3806aa2594233542 (patch) | |
tree | c73795182b60fcc17a5e5b13864eb84758d96c29 /tests | |
parent | /limits => /config (diff) | |
parent | Merge pull request #3464 from matrix-org/hawkowl/isort-run (diff) | |
download | synapse-a7f4ebbd3e786ce2e156127a3806aa2594233542.tar.xz |
Merge branch 'develop' into hs/upload-limits
Diffstat (limited to 'tests')
80 files changed, 1628 insertions, 469 deletions
diff --git a/tests/__init__.py b/tests/__init__.py index bfebb0f644..24006c949e 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -12,3 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from twisted.trial import util + +util.DEFAULT_TIMEOUT_DURATION = 10 diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 4575dd9834..5f158ec4b9 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -13,16 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pymacaroons from mock import Mock + +import pymacaroons + from twisted.internet import defer import synapse.handlers.auth from synapse.api.auth import Auth from synapse.api.errors import AuthError from synapse.types import UserID + from tests import unittest -from tests.utils import setup_test_homeserver, mock_getRawHeaders +from tests.utils import mock_getRawHeaders, setup_test_homeserver class TestHandlers(object): @@ -86,16 +89,53 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token(self): - app_service = Mock(token="foobar", url="a_url", sender=self.test_user) + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=None, + ) + self.store.get_app_service_by_token = Mock(return_value=app_service) + self.store.get_user_by_access_token = Mock(return_value=None) + + request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" + request.args["access_token"] = [self.test_token] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = yield self.auth.get_user_by_req(request) + self.assertEquals(requester.user.to_string(), self.test_user) + + @defer.inlineCallbacks + def test_get_user_by_req_appservice_valid_token_good_ip(self): + from netaddr import IPSet + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=IPSet(["192.168/16"]), + ) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) + request.getClientIP.return_value = "192.168.10.10" request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) + def test_get_user_by_req_appservice_valid_token_bad_ip(self): + from netaddr import IPSet + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=IPSet(["192.168/16"]), + ) + self.store.get_app_service_by_token = Mock(return_value=app_service) + self.store.get_user_by_access_token = Mock(return_value=None) + + request = Mock(args={}) + request.getClientIP.return_value = "131.111.8.42" + request.args["access_token"] = [self.test_token] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + d = self.auth.get_user_by_req(request) + self.failureResultOf(d, AuthError) + def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None) @@ -119,12 +159,16 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token_valid_user_id(self): masquerading_user_id = "@doppelganger:matrix.org" - app_service = Mock(token="foobar", url="a_url", sender=self.test_user) + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=None, + ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -133,12 +177,16 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_valid_token_bad_user_id(self): masquerading_user_id = "@doppelganger:matrix.org" - app_service = Mock(token="foobar", url="a_url", sender=self.test_user) + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=None, + ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index dcceca7f3e..836a23fb54 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -13,19 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock -from tests.utils import ( - MockHttpResource, DeferredMockCallable, setup_test_homeserver -) +import jsonschema + +from twisted.internet import defer + +from synapse.api.errors import SynapseError from synapse.api.filtering import Filter from synapse.events import FrozenEvent -from synapse.api.errors import SynapseError -import jsonschema +from tests import unittest +from tests.utils import DeferredMockCallable, MockHttpResource, setup_test_homeserver user_localpart = "test_user" diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 5b2b95860a..891e0cc973 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -12,14 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.appservice import ApplicationService +import re + +from mock import Mock from twisted.internet import defer -from mock import Mock -from tests import unittest +from synapse.appservice import ApplicationService -import re +from tests import unittest def _regex(regex, exclusive=True): diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index e5a902f734..b9f4863e9a 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -12,15 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + +from twisted.internet import defer + from synapse.appservice import ApplicationServiceState from synapse.appservice.scheduler import ( - _ServiceQueuer, _TransactionController, _Recoverer + _Recoverer, + _ServiceQueuer, + _TransactionController, ) -from twisted.internet import defer -from ..utils import MockClock -from mock import Mock +from synapse.util.logcontext import make_deferred_yieldable + from tests import unittest +from ..utils import MockClock + class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): @@ -204,7 +211,9 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): def test_send_single_event_with_queue(self): d = defer.Deferred() - self.txn_ctrl.send = Mock(return_value=d) + self.txn_ctrl.send = Mock( + side_effect=lambda x, y: make_deferred_yieldable(d), + ) service = Mock(id=4) event = Mock(event_id="first") event2 = Mock(event_id="second") @@ -235,7 +244,10 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase): srv_2_event2 = Mock(event_id="srv2b") send_return_list = [srv_1_defer, srv_2_defer] - self.txn_ctrl.send = Mock(side_effect=lambda x, y: send_return_list.pop(0)) + + def do_send(x, y): + return make_deferred_yieldable(send_return_list.pop(0)) + self.txn_ctrl.send = Mock(side_effect=do_send) # send events for different ASes and make sure they are sent self.queuer.enqueue(srv1, srv_1_event) diff --git a/tests/config/test_generate.py b/tests/config/test_generate.py index 879159ccea..eb7f0ab12a 100644 --- a/tests/config/test_generate.py +++ b/tests/config/test_generate.py @@ -19,6 +19,7 @@ import shutil import tempfile from synapse.config.homeserver import HomeServerConfig + from tests import unittest diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 772afd2cf9..5c422eff38 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -15,8 +15,11 @@ import os.path import shutil import tempfile + import yaml + from synapse.config.homeserver import HomeServerConfig + from tests import unittest diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 47cb328a01..cd11871b80 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -14,15 +14,13 @@ # limitations under the License. -from tests import unittest - -from synapse.events.builder import EventBuilder -from synapse.crypto.event_signing import add_hashes_and_signatures - +import nacl.signing from unpaddedbase64 import decode_base64 -import nacl.signing +from synapse.crypto.event_signing import add_hashes_and_signatures +from synapse.events.builder import EventBuilder +from tests import unittest # Perform these tests using given secret key so we get entirely deterministic # signatures output that we can test against. diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 149e443022..a9d37fe084 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -14,15 +14,19 @@ # limitations under the License. import time +from mock import Mock + import signedjson.key import signedjson.sign -from mock import Mock + +from twisted.internet import defer, reactor + from synapse.api.errors import SynapseError from synapse.crypto import keyring -from synapse.util import async, logcontext +from synapse.util import Clock, logcontext from synapse.util.logcontext import LoggingContext + from tests import unittest, utils -from twisted.internet import defer class MockPerspectiveServer(object): @@ -118,6 +122,7 @@ class KeyringTestCase(unittest.TestCase): @defer.inlineCallbacks def test_verify_json_objects_for_server_awaits_previous_requests(self): + clock = Clock(reactor) key1 = signedjson.key.generate_signing_key(1) kr = keyring.Keyring(self.hs) @@ -167,7 +172,7 @@ class KeyringTestCase(unittest.TestCase): # wait a tick for it to send the request to the perspectives server # (it first tries the datastore) - yield async.sleep(1) # XXX find out why this takes so long! + yield clock.sleep(1) # XXX find out why this takes so long! self.http_client.post_json.assert_called_once() self.assertIs(LoggingContext.current_context(), context_11) @@ -183,7 +188,7 @@ class KeyringTestCase(unittest.TestCase): res_deferreds_2 = kr.verify_json_objects_for_server( [("server10", json1)], ) - yield async.sleep(1) + yield clock.sleep(1) self.http_client.post_json.assert_not_called() res_deferreds_2[0].addBoth(self.check_context, None) diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index dfc870066e..f51d99419e 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -14,11 +14,11 @@ # limitations under the License. -from .. import unittest - from synapse.events import FrozenEvent from synapse.events.utils import prune_event, serialize_event +from .. import unittest + def MockEvent(**kwargs): if "event_id" not in kwargs: diff --git a/tests/metrics/__init__.py b/tests/federation/__init__.py index e69de29bb2..e69de29bb2 100644 --- a/tests/metrics/__init__.py +++ b/tests/federation/__init__.py diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py new file mode 100644 index 0000000000..c91e25f54f --- /dev/null +++ b/tests/federation/test_federation_server.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.events import FrozenEvent +from synapse.federation.federation_server import server_matches_acl_event + +from tests import unittest + + +@unittest.DEBUG +class ServerACLsTestCase(unittest.TestCase): + def test_blacklisted_server(self): + e = _create_acl_event({ + "allow": ["*"], + "deny": ["evil.com"], + }) + logging.info("ACL event: %s", e.content) + + self.assertFalse(server_matches_acl_event("evil.com", e)) + self.assertFalse(server_matches_acl_event("EVIL.COM", e)) + + self.assertTrue(server_matches_acl_event("evil.com.au", e)) + self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e)) + + def test_block_ip_literals(self): + e = _create_acl_event({ + "allow_ip_literals": False, + "allow": ["*"], + }) + logging.info("ACL event: %s", e.content) + + self.assertFalse(server_matches_acl_event("1.2.3.4", e)) + self.assertTrue(server_matches_acl_event("1a.2.3.4", e)) + self.assertFalse(server_matches_acl_event("[1:2::]", e)) + self.assertTrue(server_matches_acl_event("1:2:3:4", e)) + + +def _create_acl_event(content): + return FrozenEvent({ + "room_id": "!a:b", + "event_id": "$a:b", + "type": "m.room.server_acls", + "sender": "@a:b", + "content": content + }) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index b753455943..57c0771cf3 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -13,13 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from twisted.internet import defer -from .. import unittest -from tests.utils import MockClock from synapse.handlers.appservice import ApplicationServicesHandler -from mock import Mock +from tests.utils import MockClock + +from .. import unittest class AppServiceHandlerTestCase(unittest.TestCase): diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 1822dcf1e0..2e5e8e4dec 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -14,11 +14,13 @@ # limitations under the License. import pymacaroons + from twisted.internet import defer import synapse import synapse.api.errors from synapse.handlers.auth import AuthHandler + from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index 778ff2f6e9..633a0b7f36 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -17,8 +17,8 @@ from twisted.internet import defer import synapse.api.errors import synapse.handlers.device - import synapse.storage + from tests import unittest, utils user1 = "@boris:aaa" diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 7e5332e272..a353070316 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -14,14 +14,14 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock +from twisted.internet import defer + from synapse.handlers.directory import DirectoryHandler from synapse.types import RoomAlias +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index d1bd87b898..ca1542236d 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -14,13 +14,14 @@ # limitations under the License. import mock -from synapse.api import errors + from twisted.internet import defer import synapse.api.errors import synapse.handlers.e2e_keys - import synapse.storage +from synapse.api import errors + from tests import unittest, utils diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index de06a6ad30..121ce78634 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -14,18 +14,22 @@ # limitations under the License. -from tests import unittest - from mock import Mock, call from synapse.api.constants import PresenceState from synapse.handlers.presence import ( - handle_update, handle_timeout, - IDLE_TIMER, SYNC_ONLINE_TIMEOUT, LAST_ACTIVE_GRANULARITY, FEDERATION_TIMEOUT, FEDERATION_PING_INTERVAL, + FEDERATION_TIMEOUT, + IDLE_TIMER, + LAST_ACTIVE_GRANULARITY, + SYNC_ONLINE_TIMEOUT, + handle_timeout, + handle_update, ) from synapse.storage.presence import UserPresenceState +from tests import unittest + class PresenceUpdateTestCase(unittest.TestCase): def test_offline_to_online(self): diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 458296ee4c..dc17918a3d 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -14,16 +14,16 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock, NonCallableMock +from twisted.internet import defer + import synapse.types from synapse.api.errors import AuthError from synapse.handlers.profile import ProfileHandler from synapse.types import UserID +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e990e45220..025fa1be81 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from twisted.internet import defer -from .. import unittest from synapse.handlers.register import RegistrationHandler from synapse.types import UserID, create_requester from tests.utils import setup_test_homeserver -from mock import Mock +from .. import unittest class RegistrationHandlers(object): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index a433bbfa8a..b08856f763 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -14,19 +14,24 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - -from mock import Mock, call, ANY import json -from ..utils import ( - MockHttpResource, MockClock, DeferredMockCallable, setup_test_homeserver -) +from mock import ANY, Mock, call + +from twisted.internet import defer from synapse.api.errors import AuthError from synapse.types import UserID +from tests import unittest + +from ..utils import ( + DeferredMockCallable, + MockClock, + MockHttpResource, + setup_test_homeserver, +) + def _expect_edu(destination, edu_type, content, origin="test"): return { diff --git a/tests/http/__init__.py b/tests/http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/http/__init__.py diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py new file mode 100644 index 0000000000..60e6a75953 --- /dev/null +++ b/tests/http/test_endpoint.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name + +from tests import unittest + + +class ServerNameTestCase(unittest.TestCase): + def test_parse_server_name(self): + test_data = { + 'localhost': ('localhost', None), + 'my-example.com:1234': ('my-example.com', 1234), + '1.2.3.4': ('1.2.3.4', None), + '[0abc:1def::1234]': ('[0abc:1def::1234]', None), + '1.2.3.4:1': ('1.2.3.4', 1), + '[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080), + } + + for i, o in test_data.items(): + self.assertEqual(parse_server_name(i), o) + + def test_validate_bad_server_names(self): + test_data = [ + "", # empty + "localhost:http", # non-numeric port + "1234]", # smells like ipv6 literal but isn't + "[1234", + "underscore_.com", + "percent%65.com", + "1234:5678:80", # too many colons + ] + for i in test_data: + try: + parse_and_validate_server_name(i) + self.fail( + "Expected parse_and_validate_server_name('%s') to throw" % ( + i, + ), + ) + except ValueError: + pass diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py deleted file mode 100644 index 39bde6e3f8..0000000000 --- a/tests/metrics/test_metric.py +++ /dev/null @@ -1,173 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from tests import unittest - -from synapse.metrics.metric import ( - CounterMetric, CallbackMetric, DistributionMetric, CacheMetric -) - - -class CounterMetricTestCase(unittest.TestCase): - - def test_scalar(self): - counter = CounterMetric("scalar") - - self.assertEquals(counter.render(), [ - 'scalar 0', - ]) - - counter.inc() - - self.assertEquals(counter.render(), [ - 'scalar 1', - ]) - - counter.inc_by(2) - - self.assertEquals(counter.render(), [ - 'scalar 3' - ]) - - def test_vector(self): - counter = CounterMetric("vector", labels=["method"]) - - # Empty counter doesn't yet know what values it has - self.assertEquals(counter.render(), []) - - counter.inc("GET") - - self.assertEquals(counter.render(), [ - 'vector{method="GET"} 1', - ]) - - counter.inc("GET") - counter.inc("PUT") - - self.assertEquals(counter.render(), [ - 'vector{method="GET"} 2', - 'vector{method="PUT"} 1', - ]) - - -class CallbackMetricTestCase(unittest.TestCase): - - def test_scalar(self): - d = dict() - - metric = CallbackMetric("size", lambda: len(d)) - - self.assertEquals(metric.render(), [ - 'size 0', - ]) - - d["key"] = "value" - - self.assertEquals(metric.render(), [ - 'size 1', - ]) - - def test_vector(self): - vals = dict() - - metric = CallbackMetric("values", lambda: vals, labels=["type"]) - - self.assertEquals(metric.render(), []) - - # Keys have to be tuples, even if they're 1-element - vals[("foo",)] = 1 - vals[("bar",)] = 2 - - self.assertEquals(metric.render(), [ - 'values{type="bar"} 2', - 'values{type="foo"} 1', - ]) - - -class DistributionMetricTestCase(unittest.TestCase): - - def test_scalar(self): - metric = DistributionMetric("thing") - - self.assertEquals(metric.render(), [ - 'thing:count 0', - 'thing:total 0', - ]) - - metric.inc_by(500) - - self.assertEquals(metric.render(), [ - 'thing:count 1', - 'thing:total 500', - ]) - - def test_vector(self): - metric = DistributionMetric("queries", labels=["verb"]) - - self.assertEquals(metric.render(), []) - - metric.inc_by(300, "SELECT") - metric.inc_by(200, "SELECT") - metric.inc_by(800, "INSERT") - - self.assertEquals(metric.render(), [ - 'queries:count{verb="INSERT"} 1', - 'queries:count{verb="SELECT"} 2', - 'queries:total{verb="INSERT"} 800', - 'queries:total{verb="SELECT"} 500', - ]) - - -class CacheMetricTestCase(unittest.TestCase): - - def test_cache(self): - d = dict() - - metric = CacheMetric("cache", lambda: len(d), "cache_name") - - self.assertEquals(metric.render(), [ - 'cache:hits{name="cache_name"} 0', - 'cache:total{name="cache_name"} 0', - 'cache:size{name="cache_name"} 0', - 'cache:evicted_size{name="cache_name"} 0', - ]) - - metric.inc_misses() - d["key"] = "value" - - self.assertEquals(metric.render(), [ - 'cache:hits{name="cache_name"} 0', - 'cache:total{name="cache_name"} 1', - 'cache:size{name="cache_name"} 1', - 'cache:evicted_size{name="cache_name"} 0', - ]) - - metric.inc_hits() - - self.assertEquals(metric.render(), [ - 'cache:hits{name="cache_name"} 1', - 'cache:total{name="cache_name"} 2', - 'cache:size{name="cache_name"} 1', - 'cache:evicted_size{name="cache_name"} 0', - ]) - - metric.inc_evictions(2) - - self.assertEquals(metric.render(), [ - 'cache:hits{name="cache_name"} 1', - 'cache:total{name="cache_name"} 2', - 'cache:size{name="cache_name"} 1', - 'cache:evicted_size{name="cache_name"} 2', - ]) diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 64e07a8c93..8708c8a196 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from twisted.internet import defer, reactor -from tests import unittest - import tempfile from mock import Mock, NonCallableMock -from tests.utils import setup_test_homeserver -from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from twisted.internet import defer, reactor + from synapse.replication.tcp.client import ( - ReplicationClientHandler, ReplicationClientFactory, + ReplicationClientFactory, + ReplicationClientHandler, ) +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory + +from tests import unittest +from tests.utils import setup_test_homeserver class BaseSlavedStoreTestCase(unittest.TestCase): diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py index da54d478ce..adf226404e 100644 --- a/tests/replication/slave/storage/test_account_data.py +++ b/tests/replication/slave/storage/test_account_data.py @@ -13,11 +13,11 @@ # limitations under the License. -from ._base import BaseSlavedStoreTestCase +from twisted.internet import defer from synapse.replication.slave.storage.account_data import SlavedAccountDataStore -from twisted.internet import defer +from ._base import BaseSlavedStoreTestCase USER_ID = "@feeling:blue" TYPE = "my.type" @@ -37,10 +37,6 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1} ) - yield self.check( - "get_global_account_data_by_type_for_users", - [TYPE, [USER_ID]], {USER_ID: {"a": 1}} - ) yield self.master_store.add_account_data_for_user( USER_ID, TYPE, {"a": 2} @@ -50,7 +46,3 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): "get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2} ) - yield self.check( - "get_global_account_data_by_type_for_users", - [TYPE, [USER_ID]], {USER_ID: {"a": 2}} - ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index cb058d3142..cea01d93eb 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStoreTestCase +from twisted.internet import defer from synapse.events import FrozenEvent, _EventInternalMetadata from synapse.events.snapshot import EventContext from synapse.replication.slave.storage.events import SlavedEventStore from synapse.storage.roommember import RoomsForUser -from twisted.internet import defer - +from ._base import BaseSlavedStoreTestCase USER_ID = "@feeling:blue" USER_ID_2 = "@bright:blue" diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index 6624fe4eea..e6d670cc1f 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import BaseSlavedStoreTestCase +from twisted.internet import defer from synapse.replication.slave.storage.receipts import SlavedReceiptsStore -from twisted.internet import defer +from ._base import BaseSlavedStoreTestCase USER_ID = "@feeling:blue" ROOM_ID = "!room:blue" diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index d7cea30260..eee99ca2e0 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -1,7 +1,11 @@ -from synapse.rest.client.transactions import HttpTransactionCache -from synapse.rest.client.transactions import CLEANUP_PERIOD_MS -from twisted.internet import defer from mock import Mock, call + +from twisted.internet import defer, reactor + +from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionCache +from synapse.util import Clock +from synapse.util.logcontext import LoggingContext + from tests import unittest from tests.utils import MockClock @@ -40,6 +44,78 @@ class HttpTransactionCacheTestCase(unittest.TestCase): cb.assert_called_once_with("some_arg", keyword="arg", changing_args=0) @defer.inlineCallbacks + def test_logcontexts_with_async_result(self): + @defer.inlineCallbacks + def cb(): + yield Clock(reactor).sleep(0) + defer.returnValue("yay") + + @defer.inlineCallbacks + def test(): + with LoggingContext("c") as c1: + res = yield self.cache.fetch_or_execute(self.mock_key, cb) + self.assertIs(LoggingContext.current_context(), c1) + self.assertEqual(res, "yay") + + # run the test twice in parallel + d = defer.gatherResults([test(), test()]) + self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + yield d + self.assertIs(LoggingContext.current_context(), LoggingContext.sentinel) + + @defer.inlineCallbacks + def test_does_not_cache_exceptions(self): + """Checks that, if the callback throws an exception, it is called again + for the next request. + """ + called = [False] + + def cb(): + if called[0]: + # return a valid result the second time + return defer.succeed(self.mock_http_response) + + called[0] = True + raise Exception("boo") + + with LoggingContext("test") as test_context: + try: + yield self.cache.fetch_or_execute(self.mock_key, cb) + except Exception as e: + self.assertEqual(e.message, "boo") + self.assertIs(LoggingContext.current_context(), test_context) + + res = yield self.cache.fetch_or_execute(self.mock_key, cb) + self.assertEqual(res, self.mock_http_response) + self.assertIs(LoggingContext.current_context(), test_context) + + @defer.inlineCallbacks + def test_does_not_cache_failures(self): + """Checks that, if the callback returns a failure, it is called again + for the next request. + """ + called = [False] + + def cb(): + if called[0]: + # return a valid result the second time + return defer.succeed(self.mock_http_response) + + called[0] = True + return defer.fail(Exception("boo")) + + with LoggingContext("test") as test_context: + try: + yield self.cache.fetch_or_execute(self.mock_key, cb) + except Exception as e: + self.assertEqual(e.message, "boo") + self.assertIs(LoggingContext.current_context(), test_context) + + res = yield self.cache.fetch_or_execute(self.mock_key, cb) + self.assertEqual(res, self.mock_http_response) + self.assertIs(LoggingContext.current_context(), test_context) + + @defer.inlineCallbacks def test_cleans_up(self): cb = Mock( return_value=defer.succeed(self.mock_http_response) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index a8d09600bd..a5af36a99c 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -14,7 +14,7 @@ # limitations under the License. """ Tests REST events for /events paths.""" -from tests import unittest +from mock import Mock, NonCallableMock # twisted imports from twisted.internet import defer @@ -23,13 +23,11 @@ import synapse.rest.client.v1.events import synapse.rest.client.v1.register import synapse.rest.client.v1.room +from tests import unittest from ....utils import MockHttpResource, setup_test_homeserver from .utils import RestTestCase -from mock import Mock, NonCallableMock - - PATH_PREFIX = "/_matrix/client/api/v1" @@ -148,11 +146,16 @@ class EventStreamPermissionsTestCase(RestTestCase): @defer.inlineCallbacks def test_stream_basic_permissions(self): - # invalid token, expect 403 + # invalid token, expect 401 + # note: this is in violation of the original v1 spec, which expected + # 403. However, since the v1 spec no longer exists and the v1 + # implementation is now part of the r0 implementation, the newer + # behaviour is used instead to be consistent with the r0 spec. + # see issue #2602 (code, response) = yield self.mock_resource.trigger_get( "/events?access_token=%s" % ("invalid" + self.token, ) ) - self.assertEquals(403, code, msg=str(response)) + self.assertEquals(401, code, msg=str(response)) # valid token, expect content (code, response) = yield self.mock_resource.trigger_get( diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index deac7f100c..d71cc8e0db 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -15,12 +15,15 @@ """Tests REST events for /profile paths.""" from mock import Mock + from twisted.internet import defer import synapse.types -from synapse.api.errors import SynapseError, AuthError +from synapse.api.errors import AuthError, SynapseError from synapse.rest.client.v1 import profile + from tests import unittest + from ....utils import MockHttpResource, setup_test_homeserver myid = "@1234ABCD:test" @@ -52,7 +55,7 @@ class ProfileTestCase(unittest.TestCase): def _get_user_by_req(request=None, allow_guest=False): return synapse.types.create_requester(myid) - hs.get_v1auth().get_user_by_req = _get_user_by_req + hs.get_auth().get_user_by_req = _get_user_by_req profile.register_servlets(hs, self.mock_resource) diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py index a6a4e2ffe0..f596acb85f 100644 --- a/tests/rest/client/v1/test_register.py +++ b/tests/rest/client/v1/test_register.py @@ -13,12 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.rest.client.v1.register import CreateUserRestServlet -from twisted.internet import defer +import json + from mock import Mock + +from twisted.internet import defer + +from synapse.rest.client.v1.register import CreateUserRestServlet + from tests import unittest from tests.utils import mock_getRawHeaders -import json class CreateUserServletTestCase(unittest.TestCase): diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 7e8966a1a8..895dffa095 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -15,22 +15,21 @@ """Tests REST events for /rooms paths.""" +import json + +from mock import Mock, NonCallableMock +from six.moves.urllib import parse as urlparse + # twisted imports from twisted.internet import defer import synapse.rest.client.v1.room from synapse.api.constants import Membership - from synapse.types import UserID -import json -import urllib - from ....utils import MockHttpResource, setup_test_homeserver from .utils import RestTestCase -from mock import Mock, NonCallableMock - PATH_PREFIX = "/_matrix/client/api/v1" @@ -60,7 +59,7 @@ class RoomPermissionsTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -70,7 +69,7 @@ class RoomPermissionsTestCase(RestTestCase): synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource) - self.auth = hs.get_v1auth() + self.auth = hs.get_auth() # create some rooms under the name rmcreator_id self.uncreated_rmid = "!aa:test" @@ -425,7 +424,7 @@ class RoomsMemberListTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -507,7 +506,7 @@ class RoomsCreateTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -597,7 +596,7 @@ class RoomTopicTestCase(RestTestCase): "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -711,7 +710,7 @@ class RoomMemberStateTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -766,7 +765,7 @@ class RoomMemberStateTestCase(RestTestCase): @defer.inlineCallbacks def test_rooms_members_self(self): path = "/rooms/%s/state/m.room.member/%s" % ( - urllib.quote(self.room_id), self.user_id + urlparse.quote(self.room_id), self.user_id ) # valid join message (NOOP since we made the room) @@ -786,7 +785,7 @@ class RoomMemberStateTestCase(RestTestCase): def test_rooms_members_other(self): self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( - urllib.quote(self.room_id), self.other_id + urlparse.quote(self.room_id), self.other_id ) # valid invite message @@ -802,7 +801,7 @@ class RoomMemberStateTestCase(RestTestCase): def test_rooms_members_other_custom_keys(self): self.other_id = "@zzsid1:red" path = "/rooms/%s/state/m.room.member/%s" % ( - urllib.quote(self.room_id), self.other_id + urlparse.quote(self.room_id), self.other_id ) # valid invite message with custom key @@ -843,7 +842,7 @@ class RoomMessagesTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -859,7 +858,7 @@ class RoomMessagesTestCase(RestTestCase): @defer.inlineCallbacks def test_invalid_puts(self): path = "/rooms/%s/send/m.room.message/mid1" % ( - urllib.quote(self.room_id)) + urlparse.quote(self.room_id)) # missing keys or invalid json (code, response) = yield self.mock_resource.trigger( "PUT", path, '{}' @@ -894,7 +893,7 @@ class RoomMessagesTestCase(RestTestCase): @defer.inlineCallbacks def test_rooms_messages_sent(self): path = "/rooms/%s/send/m.room.message/mid1" % ( - urllib.quote(self.room_id)) + urlparse.quote(self.room_id)) content = '{"body":"test","msgtype":{"type":"a"}}' (code, response) = yield self.mock_resource.trigger("PUT", path, content) @@ -911,7 +910,7 @@ class RoomMessagesTestCase(RestTestCase): # m.text message type path = "/rooms/%s/send/m.room.message/mid2" % ( - urllib.quote(self.room_id)) + urlparse.quote(self.room_id)) content = '{"body":"test2","msgtype":"m.text"}' (code, response) = yield self.mock_resource.trigger("PUT", path, content) self.assertEquals(200, code, msg=str(response)) @@ -945,7 +944,7 @@ class RoomInitialSyncTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) @@ -1017,7 +1016,7 @@ class RoomMessageListTestCase(RestTestCase): "token_id": 1, "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index 2ec4ecab5b..bddb3302e4 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -15,18 +15,17 @@ """Tests REST events for /rooms paths.""" +from mock import Mock, NonCallableMock + # twisted imports from twisted.internet import defer import synapse.rest.client.v1.room from synapse.types import UserID -from ....utils import MockHttpResource, MockClock, setup_test_homeserver +from ....utils import MockClock, MockHttpResource, setup_test_homeserver from .utils import RestTestCase -from mock import Mock, NonCallableMock - - PATH_PREFIX = "/_matrix/client/api/v1" @@ -68,7 +67,7 @@ class RoomTypingTestCase(RestTestCase): "is_guest": False, } - hs.get_v1auth().get_user_by_access_token = get_user_by_access_token + hs.get_auth().get_user_by_access_token = get_user_by_access_token def _insert_client_ip(*args, **kwargs): return defer.succeed(None) diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py index 3bb1dd003a..54d7ba380d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py @@ -13,16 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json +import time + # twisted imports from twisted.internet import defer -# trial imports -from tests import unittest - from synapse.api.constants import Membership -import json -import time +# trial imports +from tests import unittest class RestTestCase(unittest.TestCase): diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index 5170217d9e..f18a8a6027 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -13,16 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest - from mock import Mock -from ....utils import MockHttpResource, setup_test_homeserver +from twisted.internet import defer from synapse.types import UserID -from twisted.internet import defer +from tests import unittest +from ....utils import MockHttpResource, setup_test_homeserver PATH_PREFIX = "/_matrix/client/v2_alpha" diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py index 76b833e119..bb0b2f94ea 100644 --- a/tests/rest/client/v2_alpha/test_filter.py +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -15,16 +15,13 @@ from twisted.internet import defer -from tests import unittest - -from synapse.rest.client.v2_alpha import filter - -from synapse.api.errors import Codes - import synapse.types - +from synapse.api.errors import Codes +from synapse.rest.client.v2_alpha import filter from synapse.types import UserID +from tests import unittest + from ....utils import MockHttpResource, setup_test_homeserver PATH_PREFIX = "/_matrix/client/v2_alpha" diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py index 8aba456510..9b57a56070 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py @@ -1,12 +1,15 @@ +import json + +from mock import Mock + +from twisted.internet import defer from twisted.python import failure +from synapse.api.errors import InteractiveAuthIncompleteError, SynapseError from synapse.rest.client.v2_alpha.register import RegisterRestServlet -from synapse.api.errors import SynapseError, InteractiveAuthIncompleteError -from twisted.internet import defer -from mock import Mock + from tests import unittest from tests.utils import mock_getRawHeaders -import json class RegisterRestServletTestCase(unittest.TestCase): diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index eef38b6781..bf254a260d 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -14,21 +14,21 @@ # limitations under the License. -from twisted.internet import defer +import os +import shutil +import tempfile + +from mock import Mock + +from twisted.internet import defer, reactor from synapse.rest.media.v1._base import FileInfo -from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.filepath import MediaFilePaths +from synapse.rest.media.v1.media_storage import MediaStorage from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend -from mock import Mock - from tests import unittest -import os -import shutil -import tempfile - class MediaStorageTests(unittest.TestCase): def setUp(self): @@ -38,6 +38,7 @@ class MediaStorageTests(unittest.TestCase): self.secondary_base_path = os.path.join(self.test_dir, "secondary") hs = Mock() + hs.get_reactor = Mock(return_value=reactor) hs.config.media_store_path = self.primary_base_path storage_providers = [FileStorageProviderBackend( @@ -46,7 +47,7 @@ class MediaStorageTests(unittest.TestCase): self.filepaths = MediaFilePaths(self.primary_base_path) self.media_storage = MediaStorage( - self.primary_base_path, self.filepaths, storage_providers, + hs, self.primary_base_path, self.filepaths, storage_providers, ) def tearDown(self): diff --git a/tests/server.py b/tests/server.py new file mode 100644 index 0000000000..46223ccf05 --- /dev/null +++ b/tests/server.py @@ -0,0 +1,183 @@ +import json +from io import BytesIO + +from six import text_type + +import attr + +from twisted.internet import threads +from twisted.internet.defer import Deferred +from twisted.python.failure import Failure +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.http.site import SynapseRequest + +from tests.utils import setup_test_homeserver as _sth + + +@attr.s +class FakeChannel(object): + """ + A fake Twisted Web Channel (the part that interfaces with the + wire). + """ + + result = attr.ib(factory=dict) + + @property + def json_body(self): + if not self.result: + raise Exception("No result yet.") + return json.loads(self.result["body"]) + + def writeHeaders(self, version, code, reason, headers): + self.result["version"] = version + self.result["code"] = code + self.result["reason"] = reason + self.result["headers"] = headers + + def write(self, content): + if "body" not in self.result: + self.result["body"] = b"" + + self.result["body"] += content + + def requestDone(self, _self): + self.result["done"] = True + + def getPeer(self): + return None + + def getHost(self): + return None + + @property + def transport(self): + return self + + +class FakeSite: + """ + A fake Twisted Web Site, with mocks of the extra things that + Synapse adds. + """ + + server_version_string = b"1" + site_tag = "test" + + @property + def access_logger(self): + class FakeLogger: + def info(self, *args, **kwargs): + pass + + return FakeLogger() + + +def make_request(method, path, content=b""): + """ + Make a web request using the given method and path, feed it the + content, and return the Request and the Channel underneath. + """ + + if isinstance(content, text_type): + content = content.encode('utf8') + + site = FakeSite() + channel = FakeChannel() + + req = SynapseRequest(site, channel) + req.process = lambda: b"" + req.content = BytesIO(content) + req.requestReceived(method, path, b"1.1") + + return req, channel + + +def wait_until_result(clock, channel, timeout=100): + """ + Wait until the channel has a result. + """ + clock.run() + x = 0 + + while not channel.result: + x += 1 + + if x > timeout: + raise Exception("Timed out waiting for request to finish.") + + clock.advance(0.1) + + +class ThreadedMemoryReactorClock(MemoryReactorClock): + """ + A MemoryReactorClock that supports callFromThread. + """ + def callFromThread(self, callback, *args, **kwargs): + """ + Make the callback fire in the next reactor iteration. + """ + d = Deferred() + d.addCallback(lambda x: callback(*args, **kwargs)) + self.callLater(0, d.callback, True) + return d + + +def setup_test_homeserver(*args, **kwargs): + """ + Set up a synchronous test server, driven by the reactor used by + the homeserver. + """ + d = _sth(*args, **kwargs).result + + # Make the thread pool synchronous. + clock = d.get_clock() + pool = d.get_db_pool() + + def runWithConnection(func, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs + ) + + def runInteraction(interaction, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + interaction, + *args, + **kwargs + ) + + pool.runWithConnection = runWithConnection + pool.runInteraction = runInteraction + + class ThreadPool: + """ + Threadless thread pool. + """ + def start(self): + pass + + def callInThreadWithCallback(self, onResult, function, *args, **kwargs): + def _(res): + if isinstance(res, Failure): + onResult(False, res) + else: + onResult(True, res) + + d = Deferred() + d.addCallback(lambda x: function(*args, **kwargs)) + d.addBoth(_) + clock._reactor.callLater(0, d.callback, True) + return d + + clock.threadpool = ThreadPool() + pool.threadpool = ThreadPool() + return d diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index 3cfa21c9f8..6d6f00c5c5 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -14,15 +14,15 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer - from mock import Mock -from synapse.util.async import ObservableDeferred +from twisted.internet import defer +from synapse.util.async import ObservableDeferred from synapse.util.caches.descriptors import Cache, cached +from tests import unittest + class CacheTestCase(unittest.TestCase): diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 00825498b1..099861b27c 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -12,21 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json +import os import tempfile -from synapse.config._base import ConfigError -from tests import unittest + +from mock import Mock + +import yaml + from twisted.internet import defer -from tests.utils import setup_test_homeserver from synapse.appservice import ApplicationService, ApplicationServiceState +from synapse.config._base import ConfigError from synapse.storage.appservice import ( - ApplicationServiceStore, ApplicationServiceTransactionStore + ApplicationServiceStore, + ApplicationServiceTransactionStore, ) -import json -import os -import yaml -from mock import Mock +from tests import unittest +from tests.utils import setup_test_homeserver class ApplicationServiceStoreTestCase(unittest.TestCase): diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index 1286b4ce2d..ab1f310572 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -1,10 +1,10 @@ -from tests import unittest +from mock import Mock + from twisted.internet import defer +from tests import unittest from tests.utils import setup_test_homeserver -from mock import Mock - class BackgroundUpdateTestCase(unittest.TestCase): diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index 0ac910e76f..1d1234ee39 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -14,18 +14,18 @@ # limitations under the License. -from tests import unittest -from twisted.internet import defer +from collections import OrderedDict from mock import Mock -from collections import OrderedDict +from twisted.internet import defer from synapse.server import HomeServer - from synapse.storage._base import SQLBaseStore from synapse.storage.engines import create_engine +from tests import unittest + class SQLBaseStoreTestCase(unittest.TestCase): """ Test the "simple" SQL generating methods in SQLBaseStore. """ diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index f8725acea0..a54cc6bc32 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -16,6 +16,7 @@ from twisted.internet import defer import synapse.api.errors + import tests.unittest import tests.utils diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py index 95709cd50a..129ebaf343 100644 --- a/tests/storage/test_directory.py +++ b/tests/storage/test_directory.py @@ -14,12 +14,12 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer from synapse.storage.directory import DirectoryStore -from synapse.types import RoomID, RoomAlias +from synapse.types import RoomAlias, RoomID +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/storage/test_event_push_actions.py b/tests/storage/test_event_push_actions.py index 575374c6a6..8430fc7ba6 100644 --- a/tests/storage/test_event_push_actions.py +++ b/tests/storage/test_event_push_actions.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from mock import Mock + from twisted.internet import defer import tests.unittest import tests.utils -from mock import Mock USER_ID = "@user:example.com" @@ -55,7 +56,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): def _assert_counts(noitf_count, highlight_count): counts = yield self.store.runInteraction( "", self.store._get_unread_counts_by_pos_txn, - room_id, user_id, 0, 0 + room_id, user_id, 0 ) self.assertEquals( counts, @@ -86,7 +87,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): def _mark_read(stream, depth): return self.store.runInteraction( "", self.store._remove_old_push_actions_before_txn, - room_id, user_id, depth, stream + room_id, user_id, stream ) yield _assert_counts(0, 0) @@ -128,7 +129,6 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase): yield _rotate(10) yield _assert_counts(1, 1) - @tests.unittest.DEBUG @defer.inlineCallbacks def test_find_first_stream_ordering_after_ts(self): def add_event(so, ts): diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 0be790d8f8..3a3d002782 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -14,6 +14,7 @@ # limitations under the License. import signedjson.key + from twisted.internet import defer import tests.unittest diff --git a/tests/storage/test_presence.py b/tests/storage/test_presence.py index f5fcb611d4..3276b39504 100644 --- a/tests/storage/test_presence.py +++ b/tests/storage/test_presence.py @@ -14,13 +14,13 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer from synapse.storage.presence import PresenceStore from synapse.types import UserID -from tests.utils import setup_test_homeserver, MockClock +from tests import unittest +from tests.utils import MockClock, setup_test_homeserver class PresenceStoreTestCase(unittest.TestCase): diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index 423710c9c1..2c95e5e95a 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -14,12 +14,12 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer from synapse.storage.profile import ProfileStore from synapse.types import UserID +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py index 888ddfaddd..475ec900c4 100644 --- a/tests/storage/test_redaction.py +++ b/tests/storage/test_redaction.py @@ -14,16 +14,16 @@ # limitations under the License. -from tests import unittest +from mock import Mock + from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.types import UserID, RoomID +from synapse.types import RoomID, UserID +from tests import unittest from tests.utils import setup_test_homeserver -from mock import Mock - class RedactionTestCase(unittest.TestCase): diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 7c7b164ee6..7821ea3fa3 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -14,9 +14,9 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer +from tests import unittest from tests.utils import setup_test_homeserver @@ -42,9 +42,15 @@ class RegistrationStoreTestCase(unittest.TestCase): yield self.store.register(self.user_id, self.tokens[0], self.pwhash) self.assertEquals( - # TODO(paul): Surely this field should be 'user_id', not 'name' - # Additionally surely it shouldn't come in a 1-element list - {"name": self.user_id, "password_hash": self.pwhash, "is_guest": 0}, + { + # TODO(paul): Surely this field should be 'user_id', not 'name' + "name": self.user_id, + "password_hash": self.pwhash, + "is_guest": 0, + "consent_version": None, + "consent_server_notice_sent": None, + "appservice_id": None, + }, (yield self.store.get_user_by_id(self.user_id)) ) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index ef8a4d234f..ae8ae94b6d 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -14,12 +14,12 @@ # limitations under the License. -from tests import unittest from twisted.internet import defer from synapse.api.constants import EventTypes -from synapse.types import UserID, RoomID, RoomAlias +from synapse.types import RoomAlias, RoomID, UserID +from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/storage/test_roommember.py b/tests/storage/test_roommember.py index 657b279e5d..c5fd54f67e 100644 --- a/tests/storage/test_roommember.py +++ b/tests/storage/test_roommember.py @@ -14,16 +14,16 @@ # limitations under the License. -from tests import unittest +from mock import Mock + from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.types import UserID, RoomID +from synapse.types import RoomID, UserID +from tests import unittest from tests.utils import setup_test_homeserver -from mock import Mock - class RoomMemberStoreTestCase(unittest.TestCase): diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 0891308f25..23fad12bca 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -17,6 +17,7 @@ from twisted.internet import defer from synapse.storage import UserDirectoryStore from synapse.storage.roommember import ProfileInfo + from tests import unittest from tests.utils import setup_test_homeserver diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 010aeaee7e..04a88056f1 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -13,13 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import unittest -from twisted.internet import defer - from mock import Mock, patch +from twisted.internet import defer + from synapse.util.distributor import Distributor -from synapse.util.async import run_on_reactor + +from . import unittest class DistributorTestCase(unittest.TestCase): @@ -95,7 +95,6 @@ class DistributorTestCase(unittest.TestCase): @defer.inlineCallbacks def observer(): - yield run_on_reactor() raise MyException("Oopsie") self.dist.observe("whail", observer) diff --git a/tests/test_dns.py b/tests/test_dns.py index af607d626f..b647d92697 100644 --- a/tests/test_dns.py +++ b/tests/test_dns.py @@ -13,16 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import unittest +from mock import Mock + from twisted.internet import defer from twisted.names import dns, error -from mock import Mock - from synapse.http.endpoint import resolve_service from tests.utils import MockClock +from . import unittest + @unittest.DEBUG class DnsTestCase(unittest.TestCase): @@ -62,7 +63,7 @@ class DnsTestCase(unittest.TestCase): dns_client_mock = Mock() dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) - service_name = "test_service.examle.com" + service_name = "test_service.example.com" entry = Mock(spec_set=["expires"]) entry.expires = 0 @@ -87,7 +88,7 @@ class DnsTestCase(unittest.TestCase): dns_client_mock = Mock(spec_set=['lookupService']) dns_client_mock.lookupService = Mock(spec_set=[]) - service_name = "test_service.examle.com" + service_name = "test_service.example.com" entry = Mock(spec_set=["expires"]) entry.expires = 999999999 @@ -111,7 +112,7 @@ class DnsTestCase(unittest.TestCase): dns_client_mock.lookupService.return_value = defer.fail(error.DNSServerError()) - service_name = "test_service.examle.com" + service_name = "test_service.example.com" cache = {} @@ -126,7 +127,7 @@ class DnsTestCase(unittest.TestCase): dns_client_mock.lookupService.return_value = defer.fail(error.DNSNameError()) - service_name = "test_service.examle.com" + service_name = "test_service.example.com" cache = {} diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py new file mode 100644 index 0000000000..06112430e5 --- /dev/null +++ b/tests/test_event_auth.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from synapse import event_auth +from synapse.api.errors import AuthError +from synapse.events import FrozenEvent + + +class EventAuthTestCase(unittest.TestCase): + def test_random_users_cannot_send_state_before_first_pl(self): + """ + Check that, before the first PL lands, the creator is the only user + that can send a state event. + """ + creator = "@creator:example.com" + joiner = "@joiner:example.com" + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.member", joiner): _join_event(joiner), + } + + # creator should be able to send state + event_auth.check( + _random_state_event(creator), auth_events, + do_sig_check=False, + ) + + # joiner should not be able to send state + self.assertRaises( + AuthError, + event_auth.check, + _random_state_event(joiner), + auth_events, + do_sig_check=False, + ), + + def test_state_default_level(self): + """ + Check that users above the state_default level can send state and + those below cannot + """ + creator = "@creator:example.com" + pleb = "@joiner:example.com" + king = "@joiner2:example.com" + + auth_events = { + ("m.room.create", ""): _create_event(creator), + ("m.room.member", creator): _join_event(creator), + ("m.room.power_levels", ""): _power_levels_event(creator, { + "state_default": "30", + "users": { + pleb: "29", + king: "30", + }, + }), + ("m.room.member", pleb): _join_event(pleb), + ("m.room.member", king): _join_event(king), + } + + # pleb should not be able to send state + self.assertRaises( + AuthError, + event_auth.check, + _random_state_event(pleb), + auth_events, + do_sig_check=False, + ), + + # king should be able to send state + event_auth.check( + _random_state_event(king), auth_events, + do_sig_check=False, + ) + + +# helpers for making events + +TEST_ROOM_ID = "!test:room" + + +def _create_event(user_id): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.create", + "sender": user_id, + "content": { + "creator": user_id, + }, + }) + + +def _join_event(user_id): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.member", + "sender": user_id, + "state_key": user_id, + "content": { + "membership": "join", + }, + }) + + +def _power_levels_event(sender, content): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "m.room.power_levels", + "sender": sender, + "state_key": "", + "content": content, + }) + + +def _random_state_event(sender): + return FrozenEvent({ + "room_id": TEST_ROOM_ID, + "event_id": _get_event_id(), + "type": "test.state", + "sender": sender, + "state_key": "", + "content": { + "membership": "join", + }, + }) + + +event_count = 0 + + +def _get_event_id(): + global event_count + c = event_count + event_count += 1 + return "!%i:example.com" % (c, ) diff --git a/tests/test_federation.py b/tests/test_federation.py new file mode 100644 index 0000000000..159a136971 --- /dev/null +++ b/tests/test_federation.py @@ -0,0 +1,243 @@ + +from mock import Mock + +from twisted.internet.defer import maybeDeferred, succeed + +from synapse.events import FrozenEvent +from synapse.types import Requester, UserID +from synapse.util import Clock + +from tests import unittest +from tests.server import ThreadedMemoryReactorClock, setup_test_homeserver + + +class MessageAcceptTests(unittest.TestCase): + def setUp(self): + + self.http_client = Mock() + self.reactor = ThreadedMemoryReactorClock() + self.hs_clock = Clock(self.reactor) + self.homeserver = setup_test_homeserver( + http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor + ) + + user_id = UserID("us", "test") + our_user = Requester(user_id, None, False, None, None) + room_creator = self.homeserver.get_room_creation_handler() + room = room_creator.create_room( + our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + ) + self.reactor.advance(0.1) + self.room_id = self.successResultOf(room)["room_id"] + + # Figure out what the most recent event is + most_recent = self.successResultOf( + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + )[0] + + join_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$join:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.member", + "origin": "test.servx", + "content": {"membership": "join"}, + "auth_events": [], + "prev_state": [(most_recent, {})], + "prev_events": [(most_recent, {})], + } + ) + + self.handler = self.homeserver.get_handlers().federation_handler + self.handler.do_auth = lambda *a, **b: succeed(True) + self.client = self.homeserver.get_federation_client() + self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( + pdus + ) + + # Send the join, it should return None (which is not an error) + d = self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) + self.reactor.advance(1) + self.assertEqual(self.successResultOf(d), None) + + # Make sure we actually joined the room + self.assertEqual( + self.successResultOf( + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + )[0], + "$join:test.serv", + ) + + def test_cant_hide_direct_ancestors(self): + """ + If you send a message, you must be able to provide the direct + prev_events that said event references. + """ + + def post_json(destination, path, data, headers=None, timeout=0): + # If it asks us for new missing events, give them NOTHING + if path.startswith("/_matrix/federation/v1/get_missing_events/"): + return {"events": []} + + self.http_client.post_json = post_json + + # Figure out what the most recent event is + most_recent = self.successResultOf( + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + )[0] + + # Now lie about an event + lying_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "one:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [("two:test.serv", {}), (most_recent, {})], + } + ) + + d = self.handler.on_receive_pdu( + "test.serv", lying_event, sent_to_us_directly=True + ) + + # Step the reactor, so the database fetches come back + self.reactor.advance(1) + + # on_receive_pdu should throw an error + failure = self.failureResultOf(d) + self.assertEqual( + failure.value.args[0], + ( + "ERROR 403: Your server isn't divulging details about prev_events " + "referenced in this event." + ), + ) + + # Make sure the invalid event isn't there + extrem = maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") + + @unittest.DEBUG + def test_cant_hide_past_history(self): + """ + If you send a message, you must be able to provide the direct + prev_events that said event references. + """ + + def post_json(destination, path, data, headers=None, timeout=0): + if path.startswith("/_matrix/federation/v1/get_missing_events/"): + return { + "events": [ + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "three:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [("four:test.serv", {})], + } + ] + } + + self.http_client.post_json = post_json + + def get_json(destination, path, args, headers=None): + if path.startswith("/_matrix/federation/v1/state_ids/"): + d = self.successResultOf( + self.homeserver.datastore.get_state_ids_for_event("one:test.serv") + ) + + return succeed( + { + "pdu_ids": [ + y + for x, y in d.items() + if x == ("m.room.member", "@us:test") + ], + "auth_chain_ids": d.values(), + } + ) + + self.http_client.get_json = get_json + + # Figure out what the most recent event is + most_recent = self.successResultOf( + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + )[0] + + # Make a good event + good_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "one:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [(most_recent, {})], + } + ) + + d = self.handler.on_receive_pdu( + "test.serv", good_event, sent_to_us_directly=True + ) + self.reactor.advance(1) + self.assertEqual(self.successResultOf(d), None) + + bad_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "two:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [("one:test.serv", {}), ("three:test.serv", {})], + } + ) + + d = self.handler.on_receive_pdu( + "test.serv", bad_event, sent_to_us_directly=True + ) + self.reactor.advance(1) + + extrem = maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv") + + state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id) + self.reactor.advance(1) + self.assertIn(("m.room.member", "@us:test"), self.successResultOf(state).keys()) diff --git a/tests/test_preview.py b/tests/test_preview.py index 5bd36c74aa..446843367e 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -13,12 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import unittest - from synapse.rest.media.v1.preview_url_resource import ( - summarize_paragraphs, decode_and_calc_og + decode_and_calc_og, + summarize_paragraphs, ) +from . import unittest + class PreviewTestCase(unittest.TestCase): diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000000..4192013f6d --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,129 @@ +import json +import re + +from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.api.errors import Codes, SynapseError +from synapse.http.server import JsonResource +from synapse.util import Clock + +from tests import unittest +from tests.server import make_request, setup_test_homeserver + + +class JsonResourceTests(unittest.TestCase): + def setUp(self): + self.reactor = MemoryReactorClock() + self.hs_clock = Clock(self.reactor) + self.homeserver = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.reactor + ) + + def test_handler_for_request(self): + """ + JsonResource.handler_for_request gives correctly decoded URL args to + the callback, while Twisted will give the raw bytes of URL query + arguments. + """ + got_kwargs = {} + + def _callback(request, **kwargs): + got_kwargs.update(kwargs) + return (200, kwargs) + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo/(?P<room_id>[^/]*)$")], _callback) + + request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83") + request.render(res) + + self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) + self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"}) + + def test_callback_direct_exception(self): + """ + If the web callback raises an uncaught exception, it will be translated + into a 500. + """ + + def _callback(request, **kwargs): + raise Exception("boo") + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo$")], _callback) + + request, channel = make_request(b"GET", b"/foo") + request.render(res) + + self.assertEqual(channel.result["code"], b'500') + + def test_callback_indirect_exception(self): + """ + If the web callback raises an uncaught exception in a Deferred, it will + be translated into a 500. + """ + + def _throw(*args): + raise Exception("boo") + + def _callback(request, **kwargs): + d = Deferred() + d.addCallback(_throw) + self.reactor.callLater(1, d.callback, True) + return d + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo$")], _callback) + + request, channel = make_request(b"GET", b"/foo") + request.render(res) + + # No error has been raised yet + self.assertTrue("code" not in channel.result) + + # Advance time, now there's an error + self.reactor.advance(1) + self.assertEqual(channel.result["code"], b'500') + + def test_callback_synapseerror(self): + """ + If the web callback raises a SynapseError, it returns the appropriate + status code and message set in it. + """ + + def _callback(request, **kwargs): + raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN) + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo$")], _callback) + + request, channel = make_request(b"GET", b"/foo") + request.render(res) + + self.assertEqual(channel.result["code"], b'403') + reply_body = json.loads(channel.result["body"]) + self.assertEqual(reply_body["error"], "Forbidden!!one!") + self.assertEqual(reply_body["errcode"], "M_FORBIDDEN") + + def test_no_handler(self): + """ + If there is no handler to process the request, Synapse will return 400. + """ + + def _callback(request, **kwargs): + """ + Not ever actually called! + """ + self.fail("shouldn't ever get here") + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo$")], _callback) + + request, channel = make_request(b"GET", b"/foobar") + request.render(res) + + self.assertEqual(channel.result["code"], b'400') + reply_body = json.loads(channel.result["body"]) + self.assertEqual(reply_body["error"], "Unrecognized request") + self.assertEqual(reply_body["errcode"], "M_UNRECOGNIZED") diff --git a/tests/test_state.py b/tests/test_state.py index a5c5e55951..c0f2d1152d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -13,18 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest +from mock import Mock + from twisted.internet import defer -from synapse.events import FrozenEvent from synapse.api.auth import Auth from synapse.api.constants import EventTypes, Membership +from synapse.events import FrozenEvent from synapse.state import StateHandler, StateResolutionHandler -from .utils import MockClock - -from mock import Mock +from tests import unittest +from .utils import MockClock _next_event_id = 1000 @@ -606,6 +606,14 @@ class StateTestCase(unittest.TestCase): } ) + power_levels = create_event( + type=EventTypes.PowerLevels, state_key="", + content={"users": { + "@foo:bar": "100", + "@user_id:example.com": "100", + }} + ) + creation = create_event( type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"} @@ -613,12 +621,14 @@ class StateTestCase(unittest.TestCase): old_state_1 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=1), ] old_state_2 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=2), ] @@ -633,7 +643,7 @@ class StateTestCase(unittest.TestCase): ) self.assertEqual( - old_state_2[2].event_id, context.current_state_ids[("test1", "1")] + old_state_2[3].event_id, context.current_state_ids[("test1", "1")] ) # Reverse the depth to make sure we are actually using the depths @@ -641,12 +651,14 @@ class StateTestCase(unittest.TestCase): old_state_1 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=2), ] old_state_2 = [ creation, + power_levels, member_event, create_event(type="test1", state_key="1", depth=1), ] @@ -659,7 +671,7 @@ class StateTestCase(unittest.TestCase): ) self.assertEqual( - old_state_1[2].event_id, context.current_state_ids[("test1", "1")] + old_state_1[3].event_id, context.current_state_ids[("test1", "1")] ) def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2, diff --git a/tests/test_test_utils.py b/tests/test_test_utils.py index d28bb726bb..bc97c12245 100644 --- a/tests/test_test_utils.py +++ b/tests/test_test_utils.py @@ -14,7 +14,6 @@ # limitations under the License. from tests import unittest - from tests.utils import MockClock diff --git a/tests/test_types.py b/tests/test_types.py index 115def2287..729bd676c1 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -13,11 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests import unittest - from synapse.api.errors import SynapseError from synapse.server import HomeServer -from synapse.types import UserID, RoomAlias, GroupID +from synapse.types import GroupID, RoomAlias, UserID + +from tests import unittest mock_homeserver = HomeServer(hostname="my.domain") diff --git a/tests/unittest.py b/tests/unittest.py index 7b478c4294..b25f2db5d5 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -12,23 +12,40 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +import logging + import twisted +import twisted.logger from twisted.trial import unittest -import logging +from synapse.util.logcontext import LoggingContextFilter + +# Set up putting Synapse's logs into Trial's. +rootLogger = logging.getLogger() + +log_format = ( + "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s" +) + + +class ToTwistedHandler(logging.Handler): + tx_log = twisted.logger.Logger() + + def emit(self, record): + log_entry = self.format(record) + log_level = record.levelname.lower().replace('warning', 'warn') + self.tx_log.emit( + twisted.logger.LogLevel.levelWithName(log_level), + log_entry.replace("{", r"(").replace("}", r")"), + ) -# logging doesn't have a "don't log anything at all EVARRRR setting, -# but since the highest value is 50, 1000000 should do ;) -NEVER = 1000000 -handler = logging.StreamHandler() -handler.setFormatter(logging.Formatter( - "%(levelname)s:%(name)s:%(message)s [%(pathname)s:%(lineno)d]" -)) -logging.getLogger().addHandler(handler) -logging.getLogger().setLevel(NEVER) -logging.getLogger("synapse.storage.SQL").setLevel(NEVER) -logging.getLogger("synapse.storage.txn").setLevel(NEVER) +handler = ToTwistedHandler() +formatter = logging.Formatter(log_format) +handler.setFormatter(formatter) +handler.addFilter(LoggingContextFilter(request="")) +rootLogger.addHandler(handler) def around(target): @@ -61,7 +78,7 @@ class TestCase(unittest.TestCase): method = getattr(self, methodName) - level = getattr(method, "loglevel", getattr(self, "loglevel", NEVER)) + level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR)) @around(self) def setUp(orig): diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 2516fe40f4..8176a7dabd 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -13,20 +13,28 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from functools import partial import logging +from functools import partial import mock + +from twisted.internet import defer, reactor + from synapse.api.errors import SynapseError -from synapse.util import async from synapse.util import logcontext -from twisted.internet import defer from synapse.util.caches import descriptors + from tests import unittest logger = logging.getLogger(__name__) +def run_on_reactor(): + d = defer.Deferred() + reactor.callLater(0, d.callback, 0) + return logcontext.make_deferred_yieldable(d) + + class CacheTestCase(unittest.TestCase): def test_invalidate_all(self): cache = descriptors.Cache("testcache") @@ -195,7 +203,8 @@ class DescriptorTestCase(unittest.TestCase): def fn(self, arg1): @defer.inlineCallbacks def inner_fn(): - yield async.run_on_reactor() + # we want this to behave like an asynchronous function + yield run_on_reactor() raise SynapseError(400, "blah") return inner_fn() @@ -205,7 +214,12 @@ class DescriptorTestCase(unittest.TestCase): with logcontext.LoggingContext() as c1: c1.name = "c1" try: - yield obj.fn(1) + d = obj.fn(1) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) + yield d self.fail("No exception thrown") except SynapseError: pass diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index bc92f85fa6..26f2fa5800 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -14,10 +14,10 @@ # limitations under the License. -from tests import unittest - from synapse.util.caches.dictionary_cache import DictionaryCache +from tests import unittest + class DictCacheTestCase(unittest.TestCase): @@ -32,7 +32,7 @@ class DictCacheTestCase(unittest.TestCase): seq = self.cache.sequence test_value = {"test": "test_simple_cache_hit_full"} - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key) self.assertEqual(test_value, c.value) @@ -44,7 +44,7 @@ class DictCacheTestCase(unittest.TestCase): test_value = { "test": "test_simple_cache_hit_partial" } - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test"]) self.assertEqual(test_value, c.value) @@ -56,7 +56,7 @@ class DictCacheTestCase(unittest.TestCase): test_value = { "test": "test_simple_cache_miss_partial" } - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test2"]) self.assertEqual({}, c.value) @@ -70,7 +70,7 @@ class DictCacheTestCase(unittest.TestCase): "test2": "test_simple_cache_hit_miss_partial2", "test3": "test_simple_cache_hit_miss_partial3", } - self.cache.update(seq, key, test_value, full=True) + self.cache.update(seq, key, test_value) c = self.cache.get(key, ["test2"]) self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) @@ -82,13 +82,13 @@ class DictCacheTestCase(unittest.TestCase): test_value_1 = { "test": "test_simple_cache_hit_miss_partial", } - self.cache.update(seq, key, test_value_1, full=False) + self.cache.update(seq, key, test_value_1, fetched_keys=set("test")) seq = self.cache.sequence test_value_2 = { "test2": "test_simple_cache_hit_miss_partial2", } - self.cache.update(seq, key, test_value_2, full=False) + self.cache.update(seq, key, test_value_2, fetched_keys=set("test2")) c = self.cache.get(key) self.assertEqual( diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 31d24adb8b..d12b5e838b 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -14,12 +14,12 @@ # limitations under the License. -from .. import unittest - from synapse.util.caches.expiringcache import ExpiringCache from tests.utils import MockClock +from .. import unittest + class ExpiringCacheTestCase(unittest.TestCase): diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index d6e1082779..7ce5f8c258 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -14,15 +14,16 @@ # limitations under the License. -from twisted.internet import defer, reactor +import threading + from mock import NonCallableMock +from six import StringIO + +from twisted.internet import defer, reactor from synapse.util.file_consumer import BackgroundFileConsumer from tests import unittest -from six import StringIO - -import threading class FileConsumerTests(unittest.TestCase): @@ -30,7 +31,7 @@ class FileConsumerTests(unittest.TestCase): @defer.inlineCallbacks def test_pull_consumer(self): string_file = StringIO() - consumer = BackgroundFileConsumer(string_file) + consumer = BackgroundFileConsumer(string_file, reactor=reactor) try: producer = DummyPullProducer() @@ -54,7 +55,7 @@ class FileConsumerTests(unittest.TestCase): @defer.inlineCallbacks def test_push_consumer(self): string_file = BlockingStringWrite() - consumer = BackgroundFileConsumer(string_file) + consumer = BackgroundFileConsumer(string_file, reactor=reactor) try: producer = NonCallableMock(spec_set=[]) @@ -80,7 +81,7 @@ class FileConsumerTests(unittest.TestCase): @defer.inlineCallbacks def test_push_producer_feedback(self): string_file = BlockingStringWrite() - consumer = BackgroundFileConsumer(string_file) + consumer = BackgroundFileConsumer(string_file, reactor=reactor) try: producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py index 9c795d9fdb..a5a767b1ff 100644 --- a/tests/util/test_limiter.py +++ b/tests/util/test_limiter.py @@ -14,12 +14,12 @@ # limitations under the License. -from tests import unittest - from twisted.internet import defer from synapse.util.async import Limiter +from tests import unittest + class LimiterTestCase(unittest.TestCase): diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py index 4865eb4bc6..c95907b32c 100644 --- a/tests/util/test_linearizer.py +++ b/tests/util/test_linearizer.py @@ -12,13 +12,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util import async, logcontext -from tests import unittest -from twisted.internet import defer +from six.moves import range + +from twisted.internet import defer, reactor +from synapse.util import Clock, logcontext from synapse.util.async import Linearizer -from six.moves import range + +from tests import unittest class LinearizerTestCase(unittest.TestCase): @@ -53,7 +55,7 @@ class LinearizerTestCase(unittest.TestCase): self.assertEqual( logcontext.LoggingContext.current_context(), lc) if sleep: - yield async.sleep(0) + yield Clock(reactor).sleep(0) self.assertEqual( logcontext.LoggingContext.current_context(), lc) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 4850722bc5..c54001f7a4 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -1,12 +1,11 @@ import twisted.python.failure -from twisted.internet import defer -from twisted.internet import reactor -from .. import unittest +from twisted.internet import defer, reactor -from synapse.util.async import sleep -from synapse.util import logcontext +from synapse.util import Clock, logcontext from synapse.util.logcontext import LoggingContext +from .. import unittest + class LoggingContextTestCase(unittest.TestCase): @@ -22,38 +21,44 @@ class LoggingContextTestCase(unittest.TestCase): @defer.inlineCallbacks def test_sleep(self): + clock = Clock(reactor) + @defer.inlineCallbacks def competing_callback(): with LoggingContext() as competing_context: competing_context.request = "competing" - yield sleep(0) + yield clock.sleep(0) self._check_test_key("competing") reactor.callLater(0, competing_callback) with LoggingContext() as context_one: context_one.request = "one" - yield sleep(0) + yield clock.sleep(0) self._check_test_key("one") - def _test_preserve_fn(self, function): + def _test_run_in_background(self, function): sentinel_context = LoggingContext.current_context() callback_completed = [False] - @defer.inlineCallbacks - def cb(): + def test(): context_one.request = "one" - yield function() - self._check_test_key("one") + d = function() - callback_completed[0] = True + def cb(res): + self._check_test_key("one") + callback_completed[0] = True + return res + d.addCallback(cb) + + return d with LoggingContext() as context_one: context_one.request = "one" # fire off function, but don't wait on it. - logcontext.preserve_fn(cb)() + logcontext.run_in_background(test) self._check_test_key("one") @@ -80,20 +85,30 @@ class LoggingContextTestCase(unittest.TestCase): # test is done once d2 finishes return d2 - def test_preserve_fn_with_blocking_fn(self): + def test_run_in_background_with_blocking_fn(self): @defer.inlineCallbacks def blocking_function(): - yield sleep(0) + yield Clock(reactor).sleep(0) - return self._test_preserve_fn(blocking_function) + return self._test_run_in_background(blocking_function) - def test_preserve_fn_with_non_blocking_fn(self): + def test_run_in_background_with_non_blocking_fn(self): @defer.inlineCallbacks def nonblocking_function(): with logcontext.PreserveLoggingContext(): yield defer.succeed(None) - return self._test_preserve_fn(nonblocking_function) + return self._test_run_in_background(nonblocking_function) + + def test_run_in_background_with_chained_deferred(self): + # a function which returns a deferred which looks like it has been + # called, but is actually paused + def testfunc(): + return logcontext.make_deferred_yieldable( + _chained_deferred_function() + ) + + return self._test_run_in_background(testfunc) @defer.inlineCallbacks def test_make_deferred_yieldable(self): @@ -119,6 +134,22 @@ class LoggingContextTestCase(unittest.TestCase): self._check_test_key("one") @defer.inlineCallbacks + def test_make_deferred_yieldable_with_chained_deferreds(self): + sentinel_context = LoggingContext.current_context() + + with LoggingContext() as context_one: + context_one.request = "one" + + d1 = logcontext.make_deferred_yieldable(_chained_deferred_function()) + # make sure that the context was reset by make_deferred_yieldable + self.assertIs(LoggingContext.current_context(), sentinel_context) + + yield d1 + + # now it should be restored + self._check_test_key("one") + + @defer.inlineCallbacks def test_make_deferred_yieldable_on_non_deferred(self): """Check that make_deferred_yieldable does the right thing when its argument isn't actually a deferred""" @@ -132,3 +163,17 @@ class LoggingContextTestCase(unittest.TestCase): r = yield d1 self.assertEqual(r, "bum") self._check_test_key("one") + + +# a function which returns a deferred which has been "called", but +# which had a function which returned another incomplete deferred on +# its callback list, so won't yet call any other new callbacks. +def _chained_deferred_function(): + d = defer.succeed(None) + + def cb(res): + d2 = defer.Deferred() + reactor.callLater(0, d2.callback, res) + return d2 + d.addCallback(cb) + return d diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py new file mode 100644 index 0000000000..297aebbfbe --- /dev/null +++ b/tests/util/test_logformatter.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys + +from synapse.util.logformatter import LogFormatter + +from tests import unittest + + +class TestException(Exception): + pass + + +class LogFormatterTestCase(unittest.TestCase): + def test_formatter(self): + formatter = LogFormatter() + + try: + raise TestException("testytest") + except TestException: + ei = sys.exc_info() + + output = formatter.formatException(ei) + + # check the output looks vaguely sane + self.assertIn("testytest", output) + self.assertIn("Capture point", output) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index dfb78cb8bd..9b36ef4482 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -14,12 +14,12 @@ # limitations under the License. -from .. import unittest +from mock import Mock from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache -from mock import Mock +from .. import unittest class LruCacheTestCase(unittest.TestCase): diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 1d745ae1a7..24194e3b25 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -14,10 +14,10 @@ # limitations under the License. -from tests import unittest - from synapse.util.async import ReadWriteLock +from tests import unittest + class ReadWriteLockTestCase(unittest.TestCase): diff --git a/tests/util/test_snapshot_cache.py b/tests/util/test_snapshot_cache.py index d3a8630c2f..0f5b32fcc0 100644 --- a/tests/util/test_snapshot_cache.py +++ b/tests/util/test_snapshot_cache.py @@ -14,10 +14,11 @@ # limitations under the License. -from .. import unittest +from twisted.internet.defer import Deferred from synapse.util.caches.snapshot_cache import SnapshotCache -from twisted.internet.defer import Deferred + +from .. import unittest class SnapshotCacheTestCase(unittest.TestCase): diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py new file mode 100644 index 0000000000..e3897c0d19 --- /dev/null +++ b/tests/util/test_stream_change_cache.py @@ -0,0 +1,199 @@ +from mock import patch + +from synapse.util.caches.stream_change_cache import StreamChangeCache + +from tests import unittest + + +class StreamChangeCacheTests(unittest.TestCase): + """ + Tests for StreamChangeCache. + """ + + def test_prefilled_cache(self): + """ + Providing a prefilled cache to StreamChangeCache will result in a cache + with the prefilled-cache entered in. + """ + cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2}) + self.assertTrue(cache.has_entity_changed("user@foo.com", 1)) + + def test_has_entity_changed(self): + """ + StreamChangeCache.entity_has_changed will mark entities as changed, and + has_entity_changed will observe the changed entities. + """ + cache = StreamChangeCache("#test", 3) + + cache.entity_has_changed("user@foo.com", 6) + cache.entity_has_changed("bar@baz.net", 7) + + # If it's been changed after that stream position, return True + self.assertTrue(cache.has_entity_changed("user@foo.com", 4)) + self.assertTrue(cache.has_entity_changed("bar@baz.net", 4)) + + # If it's been changed at that stream position, return False + self.assertFalse(cache.has_entity_changed("user@foo.com", 6)) + + # If there's no changes after that stream position, return False + self.assertFalse(cache.has_entity_changed("user@foo.com", 7)) + + # If the entity does not exist, return False. + self.assertFalse(cache.has_entity_changed("not@here.website", 7)) + + # If we request before the stream cache's earliest known position, + # return True, whether it's a known entity or not. + self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) + self.assertTrue(cache.has_entity_changed("not@here.website", 0)) + + @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0) + def test_has_entity_changed_pops_off_start(self): + """ + StreamChangeCache.entity_has_changed will respect the max size and + purge the oldest items upon reaching that max size. + """ + cache = StreamChangeCache("#test", 1, max_size=2) + + cache.entity_has_changed("user@foo.com", 2) + cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("user@elsewhere.org", 4) + + # The cache is at the max size, 2 + self.assertEqual(len(cache._cache), 2) + + # The oldest item has been popped off + self.assertTrue("user@foo.com" not in cache._entity_to_key) + + # If we update an existing entity, it keeps the two existing entities + cache.entity_has_changed("bar@baz.net", 5) + self.assertEqual( + set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key) + ) + + def test_get_all_entities_changed(self): + """ + StreamChangeCache.get_all_entities_changed will return all changed + entities since the given position. If the position is before the start + of the known stream, it returns None instead. + """ + cache = StreamChangeCache("#test", 1) + + cache.entity_has_changed("user@foo.com", 2) + cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("user@elsewhere.org", 4) + + self.assertEqual( + cache.get_all_entities_changed(1), + ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], + ) + self.assertEqual( + cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"] + ) + self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) + self.assertEqual(cache.get_all_entities_changed(0), None) + + def test_has_any_entity_changed(self): + """ + StreamChangeCache.has_any_entity_changed will return True if any + entities have been changed since the provided stream position, and + False if they have not. If the cache has entries and the provided + stream position is before it, it will return True, otherwise False if + the cache has no entries. + """ + cache = StreamChangeCache("#test", 1) + + # With no entities, it returns False for the past, present, and future. + self.assertFalse(cache.has_any_entity_changed(0)) + self.assertFalse(cache.has_any_entity_changed(1)) + self.assertFalse(cache.has_any_entity_changed(2)) + + # We add an entity + cache.entity_has_changed("user@foo.com", 2) + + # With an entity, it returns True for the past, the stream start + # position, and False for the stream position the entity was changed + # on and ones after it. + self.assertTrue(cache.has_any_entity_changed(0)) + self.assertTrue(cache.has_any_entity_changed(1)) + self.assertFalse(cache.has_any_entity_changed(2)) + self.assertFalse(cache.has_any_entity_changed(3)) + + def test_get_entities_changed(self): + """ + StreamChangeCache.get_entities_changed will return the entities in the + given list that have changed since the provided stream ID. If the + stream position is earlier than the earliest known position, it will + return all of the entities queried for. + """ + cache = StreamChangeCache("#test", 1) + + cache.entity_has_changed("user@foo.com", 2) + cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("user@elsewhere.org", 4) + + # Query all the entries, but mid-way through the stream. We should only + # get the ones after that point. + self.assertEqual( + cache.get_entities_changed( + ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2 + ), + set(["bar@baz.net", "user@elsewhere.org"]), + ) + + # Query all the entries mid-way through the stream, but include one + # that doesn't exist in it. We should get back the one that doesn't + # exist, too. + self.assertEqual( + cache.get_entities_changed( + [ + "user@foo.com", + "bar@baz.net", + "user@elsewhere.org", + "not@here.website", + ], + stream_pos=2, + ), + set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]), + ) + + # Query all the entries, but before the first known point. We will get + # all the entries we queried for, including ones that don't exist. + self.assertEqual( + cache.get_entities_changed( + [ + "user@foo.com", + "bar@baz.net", + "user@elsewhere.org", + "not@here.website", + ], + stream_pos=0, + ), + set( + [ + "user@foo.com", + "bar@baz.net", + "user@elsewhere.org", + "not@here.website", + ] + ), + ) + + def test_max_pos(self): + """ + StreamChangeCache.get_max_pos_of_last_change will return the most + recent point where the entity could have changed. If the entity is not + known, the stream start is provided instead. + """ + cache = StreamChangeCache("#test", 1) + + cache.entity_has_changed("user@foo.com", 2) + cache.entity_has_changed("bar@baz.net", 3) + cache.entity_has_changed("user@elsewhere.org", 4) + + # Known entities will return the point where they were changed. + self.assertEqual(cache.get_max_pos_of_last_change("user@foo.com"), 2) + self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3) + self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4) + + # Unknown entities will return the stream start position. + self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1) diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 7ab578a185..a5f2261208 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -14,10 +14,10 @@ # limitations under the License. -from .. import unittest - from synapse.util.caches.treecache import TreeCache +from .. import unittest + class TreeCacheTestCase(unittest.TestCase): def test_get_set_onelevel(self): diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py index c44567e52e..03201a4d9b 100644 --- a/tests/util/test_wheel_timer.py +++ b/tests/util/test_wheel_timer.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .. import unittest - from synapse.util.wheel_timer import WheelTimer +from .. import unittest + class WheelTimerTestCase(unittest.TestCase): def test_single_insert_fetch(self): @@ -33,7 +33,7 @@ class WheelTimerTestCase(unittest.TestCase): self.assertListEqual(wheel.fetch(156), [obj]) self.assertListEqual(wheel.fetch(170), []) - def test_mutli_insert(self): + def test_multi_insert(self): wheel = WheelTimer(bucket_size=5) obj1 = object() @@ -58,7 +58,7 @@ class WheelTimerTestCase(unittest.TestCase): wheel.insert(100, obj, 50) self.assertListEqual(wheel.fetch(120), [obj]) - def test_insert_past_mutli(self): + def test_insert_past_multi(self): wheel = WheelTimer(bucket_size=5) obj1 = object() diff --git a/tests/utils.py b/tests/utils.py index 0cd9f7eeee..6adbdbfca1 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -15,10 +15,10 @@ import hashlib from inspect import getcallargs -import urllib -import urlparse from mock import Mock, patch +from six.moves.urllib import parse as urlparse + from twisted.internet import defer, reactor from synapse.api.errors import CodeMessageException, cs_error @@ -38,11 +38,15 @@ USE_POSTGRES_FOR_TESTS = False @defer.inlineCallbacks -def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): +def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None, + **kargs): """Setup a homeserver suitable for running tests against. Keyword arguments are passed to the Homeserver constructor. If no datastore is supplied a datastore backed by an in-memory sqlite db will be given to the HS. """ + if reactor is None: + from twisted.internet import reactor + if config is None: config = Mock() config.signing_key = [MockKey()] @@ -64,6 +68,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): config.federation_rc_concurrent = 10 config.filter_timeline_limit = 5000 config.user_directory_search_all_users = False + config.user_consent_server_notice_content = None + config.block_events_without_consent_error = None # disable user directory updates, because they get done in the # background, which upsets the test runner. @@ -109,6 +115,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs): database_engine=db_engine, room_list_handler=object(), tls_server_context_factory=Mock(), + reactor=reactor, **kargs ) db_conn = hs.get_db_conn() @@ -238,7 +245,7 @@ class MockHttpResource(HttpServer): if matcher: try: args = [ - urllib.unquote(u).decode("UTF-8") + urlparse.unquote(u).decode("UTF-8") for u in matcher.groups() ] |