diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/api/test_auth.py | 51 | ||||
-rw-r--r-- | tests/federation/__init__.py | 0 | ||||
-rw-r--r-- | tests/federation/test_federation_server.py | 57 | ||||
-rw-r--r-- | tests/http/__init__.py | 0 | ||||
-rw-r--r-- | tests/http/test_endpoint.py | 55 | ||||
-rw-r--r-- | tests/server.py | 181 | ||||
-rw-r--r-- | tests/test_federation.py | 243 | ||||
-rw-r--r-- | tests/test_server.py | 128 | ||||
-rw-r--r-- | tests/unittest.py | 5 | ||||
-rw-r--r-- | tests/util/caches/test_descriptors.py | 17 |
10 files changed, 731 insertions, 6 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 4575dd9834..aec3b62897 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -86,16 +86,53 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token(self): - app_service = Mock(token="foobar", url="a_url", sender=self.test_user) + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=None, + ) + self.store.get_app_service_by_token = Mock(return_value=app_service) + self.store.get_user_by_access_token = Mock(return_value=None) + + request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" + request.args["access_token"] = [self.test_token] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = yield self.auth.get_user_by_req(request) + self.assertEquals(requester.user.to_string(), self.test_user) + + @defer.inlineCallbacks + def test_get_user_by_req_appservice_valid_token_good_ip(self): + from netaddr import IPSet + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=IPSet(["192.168/16"]), + ) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) + request.getClientIP.return_value = "192.168.10.10" request.args["access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = yield self.auth.get_user_by_req(request) self.assertEquals(requester.user.to_string(), self.test_user) + def test_get_user_by_req_appservice_valid_token_bad_ip(self): + from netaddr import IPSet + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=IPSet(["192.168/16"]), + ) + self.store.get_app_service_by_token = Mock(return_value=app_service) + self.store.get_user_by_access_token = Mock(return_value=None) + + request = Mock(args={}) + request.getClientIP.return_value = "131.111.8.42" + request.args["access_token"] = [self.test_token] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + d = self.auth.get_user_by_req(request) + self.failureResultOf(d, AuthError) + def test_get_user_by_req_appservice_bad_token(self): self.store.get_app_service_by_token = Mock(return_value=None) self.store.get_user_by_access_token = Mock(return_value=None) @@ -119,12 +156,16 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_user_by_req_appservice_valid_token_valid_user_id(self): masquerading_user_id = "@doppelganger:matrix.org" - app_service = Mock(token="foobar", url="a_url", sender=self.test_user) + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=None, + ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -133,12 +174,16 @@ class AuthTestCase(unittest.TestCase): def test_get_user_by_req_appservice_valid_token_bad_user_id(self): masquerading_user_id = "@doppelganger:matrix.org" - app_service = Mock(token="foobar", url="a_url", sender=self.test_user) + app_service = Mock( + token="foobar", url="a_url", sender=self.test_user, + ip_range_whitelist=None, + ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) self.store.get_user_by_access_token = Mock(return_value=None) request = Mock(args={}) + request.getClientIP.return_value = "127.0.0.1" request.args["access_token"] = [self.test_token] request.args["user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() diff --git a/tests/federation/__init__.py b/tests/federation/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/federation/__init__.py diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py new file mode 100644 index 0000000000..4e8dc8fea0 --- /dev/null +++ b/tests/federation/test_federation_server.py @@ -0,0 +1,57 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.events import FrozenEvent +from synapse.federation.federation_server import server_matches_acl_event +from tests import unittest + + +@unittest.DEBUG +class ServerACLsTestCase(unittest.TestCase): + def test_blacklisted_server(self): + e = _create_acl_event({ + "allow": ["*"], + "deny": ["evil.com"], + }) + logging.info("ACL event: %s", e.content) + + self.assertFalse(server_matches_acl_event("evil.com", e)) + self.assertFalse(server_matches_acl_event("EVIL.COM", e)) + + self.assertTrue(server_matches_acl_event("evil.com.au", e)) + self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e)) + + def test_block_ip_literals(self): + e = _create_acl_event({ + "allow_ip_literals": False, + "allow": ["*"], + }) + logging.info("ACL event: %s", e.content) + + self.assertFalse(server_matches_acl_event("1.2.3.4", e)) + self.assertTrue(server_matches_acl_event("1a.2.3.4", e)) + self.assertFalse(server_matches_acl_event("[1:2::]", e)) + self.assertTrue(server_matches_acl_event("1:2:3:4", e)) + + +def _create_acl_event(content): + return FrozenEvent({ + "room_id": "!a:b", + "event_id": "$a:b", + "type": "m.room.server_acls", + "sender": "@a:b", + "content": content + }) diff --git a/tests/http/__init__.py b/tests/http/__init__.py new file mode 100644 index 0000000000..e69de29bb2 --- /dev/null +++ b/tests/http/__init__.py diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py new file mode 100644 index 0000000000..b8a48d20a4 --- /dev/null +++ b/tests/http/test_endpoint.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.http.endpoint import ( + parse_server_name, + parse_and_validate_server_name, +) +from tests import unittest + + +class ServerNameTestCase(unittest.TestCase): + def test_parse_server_name(self): + test_data = { + 'localhost': ('localhost', None), + 'my-example.com:1234': ('my-example.com', 1234), + '1.2.3.4': ('1.2.3.4', None), + '[0abc:1def::1234]': ('[0abc:1def::1234]', None), + '1.2.3.4:1': ('1.2.3.4', 1), + '[0abc:1def::1234]:8080': ('[0abc:1def::1234]', 8080), + } + + for i, o in test_data.items(): + self.assertEqual(parse_server_name(i), o) + + def test_validate_bad_server_names(self): + test_data = [ + "", # empty + "localhost:http", # non-numeric port + "1234]", # smells like ipv6 literal but isn't + "[1234", + "underscore_.com", + "percent%65.com", + "1234:5678:80", # too many colons + ] + for i in test_data: + try: + parse_and_validate_server_name(i) + self.fail( + "Expected parse_and_validate_server_name('%s') to throw" % ( + i, + ), + ) + except ValueError: + pass diff --git a/tests/server.py b/tests/server.py new file mode 100644 index 0000000000..73069dff52 --- /dev/null +++ b/tests/server.py @@ -0,0 +1,181 @@ +from io import BytesIO + +import attr +import json +from six import text_type + +from twisted.python.failure import Failure +from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.http.site import SynapseRequest +from twisted.internet import threads +from tests.utils import setup_test_homeserver as _sth + + +@attr.s +class FakeChannel(object): + """ + A fake Twisted Web Channel (the part that interfaces with the + wire). + """ + + result = attr.ib(factory=dict) + + @property + def json_body(self): + if not self.result: + raise Exception("No result yet.") + return json.loads(self.result["body"]) + + def writeHeaders(self, version, code, reason, headers): + self.result["version"] = version + self.result["code"] = code + self.result["reason"] = reason + self.result["headers"] = headers + + def write(self, content): + if "body" not in self.result: + self.result["body"] = b"" + + self.result["body"] += content + + def requestDone(self, _self): + self.result["done"] = True + + def getPeer(self): + return None + + def getHost(self): + return None + + @property + def transport(self): + return self + + +class FakeSite: + """ + A fake Twisted Web Site, with mocks of the extra things that + Synapse adds. + """ + + server_version_string = b"1" + site_tag = "test" + + @property + def access_logger(self): + class FakeLogger: + def info(self, *args, **kwargs): + pass + + return FakeLogger() + + +def make_request(method, path, content=b""): + """ + Make a web request using the given method and path, feed it the + content, and return the Request and the Channel underneath. + """ + + if isinstance(content, text_type): + content = content.encode('utf8') + + site = FakeSite() + channel = FakeChannel() + + req = SynapseRequest(site, channel) + req.process = lambda: b"" + req.content = BytesIO(content) + req.requestReceived(method, path, b"1.1") + + return req, channel + + +def wait_until_result(clock, channel, timeout=100): + """ + Wait until the channel has a result. + """ + clock.run() + x = 0 + + while not channel.result: + x += 1 + + if x > timeout: + raise Exception("Timed out waiting for request to finish.") + + clock.advance(0.1) + + +class ThreadedMemoryReactorClock(MemoryReactorClock): + """ + A MemoryReactorClock that supports callFromThread. + """ + def callFromThread(self, callback, *args, **kwargs): + """ + Make the callback fire in the next reactor iteration. + """ + d = Deferred() + d.addCallback(lambda x: callback(*args, **kwargs)) + self.callLater(0, d.callback, True) + return d + + +def setup_test_homeserver(*args, **kwargs): + """ + Set up a synchronous test server, driven by the reactor used by + the homeserver. + """ + d = _sth(*args, **kwargs).result + + # Make the thread pool synchronous. + clock = d.get_clock() + pool = d.get_db_pool() + + def runWithConnection(func, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runWithConnection, + func, + *args, + **kwargs + ) + + def runInteraction(interaction, *args, **kwargs): + return threads.deferToThreadPool( + pool._reactor, + pool.threadpool, + pool._runInteraction, + interaction, + *args, + **kwargs + ) + + pool.runWithConnection = runWithConnection + pool.runInteraction = runInteraction + + class ThreadPool: + """ + Threadless thread pool. + """ + def start(self): + pass + + def callInThreadWithCallback(self, onResult, function, *args, **kwargs): + def _(res): + if isinstance(res, Failure): + onResult(False, res) + else: + onResult(True, res) + + d = Deferred() + d.addCallback(lambda x: function(*args, **kwargs)) + d.addBoth(_) + clock._reactor.callLater(0, d.callback, True) + return d + + clock.threadpool = ThreadPool() + pool.threadpool = ThreadPool() + return d diff --git a/tests/test_federation.py b/tests/test_federation.py new file mode 100644 index 0000000000..fc80a69369 --- /dev/null +++ b/tests/test_federation.py @@ -0,0 +1,243 @@ + +from twisted.internet.defer import succeed, maybeDeferred + +from synapse.util import Clock +from synapse.events import FrozenEvent +from synapse.types import Requester, UserID + +from tests import unittest +from tests.server import setup_test_homeserver, ThreadedMemoryReactorClock + +from mock import Mock + + +class MessageAcceptTests(unittest.TestCase): + def setUp(self): + + self.http_client = Mock() + self.reactor = ThreadedMemoryReactorClock() + self.hs_clock = Clock(self.reactor) + self.homeserver = setup_test_homeserver( + http_client=self.http_client, clock=self.hs_clock, reactor=self.reactor + ) + + user_id = UserID("us", "test") + our_user = Requester(user_id, None, False, None, None) + room_creator = self.homeserver.get_room_creation_handler() + room = room_creator.create_room( + our_user, room_creator.PRESETS_DICT["public_chat"], ratelimit=False + ) + self.reactor.advance(0.1) + self.room_id = self.successResultOf(room)["room_id"] + + # Figure out what the most recent event is + most_recent = self.successResultOf( + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + )[0] + + join_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "state_key": "@baduser:test.serv", + "event_id": "$join:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.member", + "origin": "test.servx", + "content": {"membership": "join"}, + "auth_events": [], + "prev_state": [(most_recent, {})], + "prev_events": [(most_recent, {})], + } + ) + + self.handler = self.homeserver.get_handlers().federation_handler + self.handler.do_auth = lambda *a, **b: succeed(True) + self.client = self.homeserver.get_federation_client() + self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( + pdus + ) + + # Send the join, it should return None (which is not an error) + d = self.handler.on_receive_pdu( + "test.serv", join_event, sent_to_us_directly=True + ) + self.reactor.advance(1) + self.assertEqual(self.successResultOf(d), None) + + # Make sure we actually joined the room + self.assertEqual( + self.successResultOf( + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + )[0], + "$join:test.serv", + ) + + def test_cant_hide_direct_ancestors(self): + """ + If you send a message, you must be able to provide the direct + prev_events that said event references. + """ + + def post_json(destination, path, data, headers=None, timeout=0): + # If it asks us for new missing events, give them NOTHING + if path.startswith("/_matrix/federation/v1/get_missing_events/"): + return {"events": []} + + self.http_client.post_json = post_json + + # Figure out what the most recent event is + most_recent = self.successResultOf( + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + )[0] + + # Now lie about an event + lying_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "one:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [("two:test.serv", {}), (most_recent, {})], + } + ) + + d = self.handler.on_receive_pdu( + "test.serv", lying_event, sent_to_us_directly=True + ) + + # Step the reactor, so the database fetches come back + self.reactor.advance(1) + + # on_receive_pdu should throw an error + failure = self.failureResultOf(d) + self.assertEqual( + failure.value.args[0], + ( + "ERROR 403: Your server isn't divulging details about prev_events " + "referenced in this event." + ), + ) + + # Make sure the invalid event isn't there + extrem = maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv") + + @unittest.DEBUG + def test_cant_hide_past_history(self): + """ + If you send a message, you must be able to provide the direct + prev_events that said event references. + """ + + def post_json(destination, path, data, headers=None, timeout=0): + if path.startswith("/_matrix/federation/v1/get_missing_events/"): + return { + "events": [ + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "three:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [("four:test.serv", {})], + } + ] + } + + self.http_client.post_json = post_json + + def get_json(destination, path, args, headers=None): + if path.startswith("/_matrix/federation/v1/state_ids/"): + d = self.successResultOf( + self.homeserver.datastore.get_state_ids_for_event("one:test.serv") + ) + + return succeed( + { + "pdu_ids": [ + y + for x, y in d.items() + if x == ("m.room.member", "@us:test") + ], + "auth_chain_ids": d.values(), + } + ) + + self.http_client.get_json = get_json + + # Figure out what the most recent event is + most_recent = self.successResultOf( + maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + )[0] + + # Make a good event + good_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "one:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [(most_recent, {})], + } + ) + + d = self.handler.on_receive_pdu( + "test.serv", good_event, sent_to_us_directly=True + ) + self.reactor.advance(1) + self.assertEqual(self.successResultOf(d), None) + + bad_event = FrozenEvent( + { + "room_id": self.room_id, + "sender": "@baduser:test.serv", + "event_id": "two:test.serv", + "depth": 1000, + "origin_server_ts": 1, + "type": "m.room.message", + "origin": "test.serv", + "content": "hewwo?", + "auth_events": [], + "prev_events": [("one:test.serv", {}), ("three:test.serv", {})], + } + ) + + d = self.handler.on_receive_pdu( + "test.serv", bad_event, sent_to_us_directly=True + ) + self.reactor.advance(1) + + extrem = maybeDeferred( + self.homeserver.datastore.get_latest_event_ids_in_room, self.room_id + ) + self.assertEqual(self.successResultOf(extrem)[0], "two:test.serv") + + state = self.homeserver.get_state_handler().get_current_state_ids(self.room_id) + self.reactor.advance(1) + self.assertIn(("m.room.member", "@us:test"), self.successResultOf(state).keys()) diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000000..8ad822c43b --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,128 @@ +import json +import re + +from twisted.internet.defer import Deferred +from twisted.test.proto_helpers import MemoryReactorClock + +from synapse.util import Clock +from synapse.api.errors import Codes, SynapseError +from synapse.http.server import JsonResource +from tests import unittest +from tests.server import make_request, setup_test_homeserver + + +class JsonResourceTests(unittest.TestCase): + def setUp(self): + self.reactor = MemoryReactorClock() + self.hs_clock = Clock(self.reactor) + self.homeserver = setup_test_homeserver( + http_client=None, clock=self.hs_clock, reactor=self.reactor + ) + + def test_handler_for_request(self): + """ + JsonResource.handler_for_request gives correctly decoded URL args to + the callback, while Twisted will give the raw bytes of URL query + arguments. + """ + got_kwargs = {} + + def _callback(request, **kwargs): + got_kwargs.update(kwargs) + return (200, kwargs) + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo/(?P<room_id>[^/]*)$")], _callback) + + request, channel = make_request(b"GET", b"/foo/%E2%98%83?a=%E2%98%83") + request.render(res) + + self.assertEqual(request.args, {b'a': [u"\N{SNOWMAN}".encode('utf8')]}) + self.assertEqual(got_kwargs, {u"room_id": u"\N{SNOWMAN}"}) + + def test_callback_direct_exception(self): + """ + If the web callback raises an uncaught exception, it will be translated + into a 500. + """ + + def _callback(request, **kwargs): + raise Exception("boo") + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo$")], _callback) + + request, channel = make_request(b"GET", b"/foo") + request.render(res) + + self.assertEqual(channel.result["code"], b'500') + + def test_callback_indirect_exception(self): + """ + If the web callback raises an uncaught exception in a Deferred, it will + be translated into a 500. + """ + + def _throw(*args): + raise Exception("boo") + + def _callback(request, **kwargs): + d = Deferred() + d.addCallback(_throw) + self.reactor.callLater(1, d.callback, True) + return d + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo$")], _callback) + + request, channel = make_request(b"GET", b"/foo") + request.render(res) + + # No error has been raised yet + self.assertTrue("code" not in channel.result) + + # Advance time, now there's an error + self.reactor.advance(1) + self.assertEqual(channel.result["code"], b'500') + + def test_callback_synapseerror(self): + """ + If the web callback raises a SynapseError, it returns the appropriate + status code and message set in it. + """ + + def _callback(request, **kwargs): + raise SynapseError(403, "Forbidden!!one!", Codes.FORBIDDEN) + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo$")], _callback) + + request, channel = make_request(b"GET", b"/foo") + request.render(res) + + self.assertEqual(channel.result["code"], b'403') + reply_body = json.loads(channel.result["body"]) + self.assertEqual(reply_body["error"], "Forbidden!!one!") + self.assertEqual(reply_body["errcode"], "M_FORBIDDEN") + + def test_no_handler(self): + """ + If there is no handler to process the request, Synapse will return 400. + """ + + def _callback(request, **kwargs): + """ + Not ever actually called! + """ + self.fail("shouldn't ever get here") + + res = JsonResource(self.homeserver) + res.register_paths("GET", [re.compile("^/foo$")], _callback) + + request, channel = make_request(b"GET", b"/foobar") + request.render(res) + + self.assertEqual(channel.result["code"], b'400') + reply_body = json.loads(channel.result["body"]) + self.assertEqual(reply_body["error"], "Unrecognized request") + self.assertEqual(reply_body["errcode"], "M_UNRECOGNIZED") diff --git a/tests/unittest.py b/tests/unittest.py index 184fe880f3..b25f2db5d5 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -35,7 +35,10 @@ class ToTwistedHandler(logging.Handler): def emit(self, record): log_entry = self.format(record) log_level = record.levelname.lower().replace('warning', 'warn') - self.tx_log.emit(twisted.logger.LogLevel.levelWithName(log_level), log_entry) + self.tx_log.emit( + twisted.logger.LogLevel.levelWithName(log_level), + log_entry.replace("{", r"(").replace("}", r")"), + ) handler = ToTwistedHandler() diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 24754591df..a94d566c96 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -19,13 +19,19 @@ import logging import mock from synapse.api.errors import SynapseError from synapse.util import logcontext -from twisted.internet import defer +from twisted.internet import defer, reactor from synapse.util.caches import descriptors from tests import unittest logger = logging.getLogger(__name__) +def run_on_reactor(): + d = defer.Deferred() + reactor.callLater(0, d.callback, 0) + return logcontext.make_deferred_yieldable(d) + + class CacheTestCase(unittest.TestCase): def test_invalidate_all(self): cache = descriptors.Cache("testcache") @@ -194,6 +200,8 @@ class DescriptorTestCase(unittest.TestCase): def fn(self, arg1): @defer.inlineCallbacks def inner_fn(): + # we want this to behave like an asynchronous function + yield run_on_reactor() raise SynapseError(400, "blah") return inner_fn() @@ -203,7 +211,12 @@ class DescriptorTestCase(unittest.TestCase): with logcontext.LoggingContext() as c1: c1.name = "c1" try: - yield obj.fn(1) + d = obj.fn(1) + self.assertEqual( + logcontext.LoggingContext.current_context(), + logcontext.LoggingContext.sentinel, + ) + yield d self.fail("No exception thrown") except SynapseError: pass |