diff --git a/tests/rest/client/v1/test_admin.py b/tests/rest/client/v1/test_admin.py
new file mode 100644
index 0000000000..8c90145601
--- /dev/null
+++ b/tests/rest/client/v1/test_admin.py
@@ -0,0 +1,305 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import hashlib
+import hmac
+import json
+
+from mock import Mock
+
+from synapse.http.server import JsonResource
+from synapse.rest.client.v1.admin import register_servlets
+from synapse.util import Clock
+
+from tests import unittest
+from tests.server import (
+ ThreadedMemoryReactorClock,
+ make_request,
+ render,
+ setup_test_homeserver,
+)
+
+
+class UserRegisterTestCase(unittest.TestCase):
+ def setUp(self):
+
+ self.clock = ThreadedMemoryReactorClock()
+ self.hs_clock = Clock(self.clock)
+ self.url = "/_matrix/client/r0/admin/register"
+
+ self.registration_handler = Mock()
+ self.identity_handler = Mock()
+ self.login_handler = Mock()
+ self.device_handler = Mock()
+ self.device_handler.check_device_registered = Mock(return_value="FAKE")
+
+ self.datastore = Mock(return_value=Mock())
+ self.datastore.get_current_state_deltas = Mock(return_value=[])
+
+ self.secrets = Mock()
+
+ self.hs = setup_test_homeserver(
+ http_client=None, clock=self.hs_clock, reactor=self.clock
+ )
+
+ self.hs.config.registration_shared_secret = u"shared"
+
+ self.hs.get_media_repository = Mock()
+ self.hs.get_deactivate_account_handler = Mock()
+
+ self.resource = JsonResource(self.hs)
+ register_servlets(self.hs, self.resource)
+
+ def test_disabled(self):
+ """
+ If there is no shared secret, registration through this method will be
+ prevented.
+ """
+ self.hs.config.registration_shared_secret = None
+
+ request, channel = make_request("POST", self.url, b'{}')
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(
+ 'Shared secret registration is not enabled', channel.json_body["error"]
+ )
+
+ def test_get_nonce(self):
+ """
+ Calling GET on the endpoint will return a randomised nonce, using the
+ homeserver's secrets provider.
+ """
+ secrets = Mock()
+ secrets.token_hex = Mock(return_value="abcd")
+
+ self.hs.get_secrets = Mock(return_value=secrets)
+
+ request, channel = make_request("GET", self.url)
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(channel.json_body, {"nonce": "abcd"})
+
+ def test_expired_nonce(self):
+ """
+ Calling GET on the endpoint will return a randomised nonce, which will
+ only last for SALT_TIMEOUT (60s).
+ """
+ request, channel = make_request("GET", self.url)
+ render(request, self.resource, self.clock)
+ nonce = channel.json_body["nonce"]
+
+ # 59 seconds
+ self.clock.advance(59)
+
+ body = json.dumps({"nonce": nonce})
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('username must be specified', channel.json_body["error"])
+
+ # 61 seconds
+ self.clock.advance(2)
+
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('unrecognised nonce', channel.json_body["error"])
+
+ def test_register_incorrect_nonce(self):
+ """
+ Only the provided nonce can be used, as it's checked in the MAC.
+ """
+ request, channel = make_request("GET", self.url)
+ render(request, self.resource, self.clock)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "mac": want_mac,
+ }
+ ).encode('utf8')
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("HMAC incorrect", channel.json_body["error"])
+
+ def test_register_correct_nonce(self):
+ """
+ When the correct nonce is provided, and the right key is provided, the
+ user is registered.
+ """
+ request, channel = make_request("GET", self.url)
+ render(request, self.resource, self.clock)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "mac": want_mac,
+ }
+ ).encode('utf8')
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["user_id"])
+
+ def test_nonce_reuse(self):
+ """
+ A valid unrecognised nonce.
+ """
+ request, channel = make_request("GET", self.url)
+ render(request, self.resource, self.clock)
+ nonce = channel.json_body["nonce"]
+
+ want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
+ want_mac.update(nonce.encode('ascii') + b"\x00bob\x00abc123\x00admin")
+ want_mac = want_mac.hexdigest()
+
+ body = json.dumps(
+ {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "mac": want_mac,
+ }
+ ).encode('utf8')
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual("@bob:test", channel.json_body["user_id"])
+
+ # Now, try and reuse it
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('unrecognised nonce', channel.json_body["error"])
+
+ def test_missing_parts(self):
+ """
+ Synapse will complain if you don't give nonce, username, password, and
+ mac. Admin is optional. Additional checks are done for length and
+ type.
+ """
+ def nonce():
+ request, channel = make_request("GET", self.url)
+ render(request, self.resource, self.clock)
+ return channel.json_body["nonce"]
+
+ #
+ # Nonce check
+ #
+
+ # Must be present
+ body = json.dumps({})
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('nonce must be specified', channel.json_body["error"])
+
+ #
+ # Username checks
+ #
+
+ # Must be present
+ body = json.dumps({"nonce": nonce()})
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('username must be specified', channel.json_body["error"])
+
+ # Must be a string
+ body = json.dumps({"nonce": nonce(), "username": 1234})
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('Invalid username', channel.json_body["error"])
+
+ # Must not have null bytes
+ body = json.dumps({"nonce": nonce(), "username": b"abcd\x00"})
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('Invalid username', channel.json_body["error"])
+
+ # Must not have null bytes
+ body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('Invalid username', channel.json_body["error"])
+
+ #
+ # Username checks
+ #
+
+ # Must be present
+ body = json.dumps({"nonce": nonce(), "username": "a"})
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('password must be specified', channel.json_body["error"])
+
+ # Must be a string
+ body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('Invalid password', channel.json_body["error"])
+
+ # Must not have null bytes
+ body = json.dumps({"nonce": nonce(), "username": "a", "password": b"abcd\x00"})
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('Invalid password', channel.json_body["error"])
+
+ # Super long
+ body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
+ request, channel = make_request("POST", self.url, body.encode('utf8'))
+ render(request, self.resource, self.clock)
+
+ self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual('Invalid password', channel.json_body["error"])
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index a5af36a99c..50418153fa 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -14,100 +14,30 @@
# limitations under the License.
""" Tests REST events for /events paths."""
+
from mock import Mock, NonCallableMock
+from six import PY3
-# twisted imports
from twisted.internet import defer
-import synapse.rest.client.v1.events
-import synapse.rest.client.v1.register
-import synapse.rest.client.v1.room
-
-from tests import unittest
-
from ....utils import MockHttpResource, setup_test_homeserver
from .utils import RestTestCase
PATH_PREFIX = "/_matrix/client/api/v1"
-class EventStreamPaginationApiTestCase(unittest.TestCase):
- """ Tests event streaming query parameters and start/end keys used in the
- Pagination stream API. """
- user_id = "sid1"
-
- def setUp(self):
- # configure stream and inject items
- pass
-
- def tearDown(self):
- pass
-
- def TODO_test_long_poll(self):
- # stream from 'end' key, send (self+other) message, expect message.
-
- # stream from 'END', send (self+other) message, expect message.
-
- # stream from 'end' key, send (self+other) topic, expect topic.
-
- # stream from 'END', send (self+other) topic, expect topic.
-
- # stream from 'end' key, send (self+other) invite, expect invite.
-
- # stream from 'END', send (self+other) invite, expect invite.
-
- pass
-
- def TODO_test_stream_forward(self):
- # stream from START, expect injected items
-
- # stream from 'start' key, expect same content
-
- # stream from 'end' key, expect nothing
-
- # stream from 'END', expect nothing
-
- # The following is needed for cases where content is removed e.g. you
- # left a room, so the token you're streaming from is > the one that
- # would be returned naturally from START>END.
- # stream from very new token (higher than end key), expect same token
- # returned as end key
- pass
-
- def TODO_test_limits(self):
- # stream from a key, expect limit_num items
-
- # stream from START, expect limit_num items
-
- pass
-
- def TODO_test_range(self):
- # stream from key to key, expect X items
-
- # stream from key to END, expect X items
-
- # stream from START to key, expect X items
-
- # stream from START to END, expect all items
- pass
-
- def TODO_test_direction(self):
- # stream from END to START and fwds, expect newest first
-
- # stream from END to START and bwds, expect oldest first
-
- # stream from START to END and fwds, expect oldest first
-
- # stream from START to END and bwds, expect newest first
-
- pass
-
-
class EventStreamPermissionsTestCase(RestTestCase):
""" Tests event streaming (GET /events). """
+ if PY3:
+ skip = "Skip on Py3 until ported to use not V1 only register."
+
@defer.inlineCallbacks
def setUp(self):
+ import synapse.rest.client.v1.events
+ import synapse.rest.client.v1_only.register
+ import synapse.rest.client.v1.room
+
self.mock_resource = MockHttpResource(prefix=PATH_PREFIX)
hs = yield setup_test_homeserver(
@@ -125,7 +55,7 @@ class EventStreamPermissionsTestCase(RestTestCase):
hs.get_handlers().federation_handler = Mock()
- synapse.rest.client.v1.register.register_servlets(hs, self.mock_resource)
+ synapse.rest.client.v1_only.register.register_servlets(hs, self.mock_resource)
synapse.rest.client.v1.events.register_servlets(hs, self.mock_resource)
synapse.rest.client.v1.room.register_servlets(hs, self.mock_resource)
diff --git a/tests/rest/client/v1/test_register.py b/tests/rest/client/v1/test_register.py
index f15fb36213..83a23cd8fe 100644
--- a/tests/rest/client/v1/test_register.py
+++ b/tests/rest/client/v1/test_register.py
@@ -16,11 +16,12 @@
import json
from mock import Mock
+from six import PY3
from twisted.test.proto_helpers import MemoryReactorClock
from synapse.http.server import JsonResource
-from synapse.rest.client.v1.register import register_servlets
+from synapse.rest.client.v1_only.register import register_servlets
from synapse.util import Clock
from tests import unittest
@@ -31,6 +32,8 @@ class CreateUserServletTestCase(unittest.TestCase):
"""
Tests for CreateUserRestServlet.
"""
+ if PY3:
+ skip = "Not ported to Python 3."
def setUp(self):
self.registration_handler = Mock()
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 6b5764095e..00fc796787 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -20,7 +20,6 @@ import json
from mock import Mock, NonCallableMock
from six.moves.urllib import parse as urlparse
-# twisted imports
from twisted.internet import defer
import synapse.rest.client.v1.room
@@ -86,6 +85,7 @@ class RoomBase(unittest.TestCase):
self.resource = JsonResource(self.hs)
synapse.rest.client.v1.room.register_servlets(self.hs, self.resource)
+ synapse.rest.client.v1.room.register_deprecated_servlets(self.hs, self.resource)
self.helper = RestHelper(self.hs, self.resource, self.user_id)
diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py
index 5ea9cc825f..e890f0feac 100644
--- a/tests/rest/client/v2_alpha/test_filter.py
+++ b/tests/rest/client/v2_alpha/test_filter.py
@@ -21,8 +21,12 @@ from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
-from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock
-from tests.server import make_request, setup_test_homeserver, wait_until_result
+from tests.server import (
+ ThreadedMemoryReactorClock as MemoryReactorClock,
+ make_request,
+ setup_test_homeserver,
+ wait_until_result,
+)
PATH_PREFIX = "/_matrix/client/v2_alpha"
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index 704cf97a40..03ec3993b2 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -20,8 +20,12 @@ from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
-from tests.server import ThreadedMemoryReactorClock as MemoryReactorClock
-from tests.server import make_request, setup_test_homeserver, wait_until_result
+from tests.server import (
+ ThreadedMemoryReactorClock as MemoryReactorClock,
+ make_request,
+ setup_test_homeserver,
+ wait_until_result,
+)
PATH_PREFIX = "/_matrix/client/v2_alpha"
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 04a88056f1..71d11cda77 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,8 +16,6 @@
from mock import Mock, patch
-from twisted.internet import defer
-
from synapse.util.distributor import Distributor
from . import unittest
@@ -27,38 +26,15 @@ class DistributorTestCase(unittest.TestCase):
def setUp(self):
self.dist = Distributor()
- @defer.inlineCallbacks
def test_signal_dispatch(self):
self.dist.declare("alert")
observer = Mock()
self.dist.observe("alert", observer)
- d = self.dist.fire("alert", 1, 2, 3)
- yield d
- self.assertTrue(d.called)
+ self.dist.fire("alert", 1, 2, 3)
observer.assert_called_with(1, 2, 3)
- @defer.inlineCallbacks
- def test_signal_dispatch_deferred(self):
- self.dist.declare("whine")
-
- d_inner = defer.Deferred()
-
- def observer():
- return d_inner
-
- self.dist.observe("whine", observer)
-
- d_outer = self.dist.fire("whine")
-
- self.assertFalse(d_outer.called)
-
- d_inner.callback(None)
- yield d_outer
- self.assertTrue(d_outer.called)
-
- @defer.inlineCallbacks
def test_signal_catch(self):
self.dist.declare("alarm")
@@ -71,9 +47,7 @@ class DistributorTestCase(unittest.TestCase):
with patch(
"synapse.util.distributor.logger", spec=["warning"]
) as mock_logger:
- d = self.dist.fire("alarm", "Go")
- yield d
- self.assertTrue(d.called)
+ self.dist.fire("alarm", "Go")
observers[0].assert_called_once_with("Go")
observers[1].assert_called_once_with("Go")
@@ -83,34 +57,12 @@ class DistributorTestCase(unittest.TestCase):
mock_logger.warning.call_args[0][0], str
)
- @defer.inlineCallbacks
- def test_signal_catch_no_suppress(self):
- # Gut-wrenching
- self.dist.suppress_failures = False
-
- self.dist.declare("whail")
-
- class MyException(Exception):
- pass
-
- @defer.inlineCallbacks
- def observer():
- raise MyException("Oopsie")
-
- self.dist.observe("whail", observer)
-
- d = self.dist.fire("whail")
-
- yield self.assertFailure(d, MyException)
- self.dist.suppress_failures = True
-
- @defer.inlineCallbacks
def test_signal_prereg(self):
observer = Mock()
self.dist.observe("flare", observer)
self.dist.declare("flare")
- yield self.dist.fire("flare", 4, 5)
+ self.dist.fire("flare", 4, 5)
observer.assert_called_with(4, 5)
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 159a136971..f40ff29b52 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -137,7 +137,6 @@ class MessageAcceptTests(unittest.TestCase):
)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
- @unittest.DEBUG
def test_cant_hide_past_history(self):
"""
If you send a message, you must be able to provide the direct
@@ -178,7 +177,7 @@ class MessageAcceptTests(unittest.TestCase):
for x, y in d.items()
if x == ("m.room.member", "@us:test")
],
- "auth_chain_ids": d.values(),
+ "auth_chain_ids": list(d.values()),
}
)
diff --git a/tests/util/test_limiter.py b/tests/util/test_limiter.py
deleted file mode 100644
index a5a767b1ff..0000000000
--- a/tests/util/test_limiter.py
+++ /dev/null
@@ -1,70 +0,0 @@
-# -*- coding: utf-8 -*-
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-from twisted.internet import defer
-
-from synapse.util.async import Limiter
-
-from tests import unittest
-
-
-class LimiterTestCase(unittest.TestCase):
-
- @defer.inlineCallbacks
- def test_limiter(self):
- limiter = Limiter(3)
-
- key = object()
-
- d1 = limiter.queue(key)
- cm1 = yield d1
-
- d2 = limiter.queue(key)
- cm2 = yield d2
-
- d3 = limiter.queue(key)
- cm3 = yield d3
-
- d4 = limiter.queue(key)
- self.assertFalse(d4.called)
-
- d5 = limiter.queue(key)
- self.assertFalse(d5.called)
-
- with cm1:
- self.assertFalse(d4.called)
- self.assertFalse(d5.called)
-
- self.assertTrue(d4.called)
- self.assertFalse(d5.called)
-
- with cm3:
- self.assertFalse(d5.called)
-
- self.assertTrue(d5.called)
-
- with cm2:
- pass
-
- with (yield d4):
- pass
-
- with (yield d5):
- pass
-
- d6 = limiter.queue(key)
- with (yield d6):
- pass
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index c95907b32c..4729bd5a0a 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,6 +17,7 @@
from six.moves import range
from twisted.internet import defer, reactor
+from twisted.internet.defer import CancelledError
from synapse.util import Clock, logcontext
from synapse.util.async import Linearizer
@@ -65,3 +67,79 @@ class LinearizerTestCase(unittest.TestCase):
func(i)
return func(1000)
+
+ @defer.inlineCallbacks
+ def test_multiple_entries(self):
+ limiter = Linearizer(max_count=3)
+
+ key = object()
+
+ d1 = limiter.queue(key)
+ cm1 = yield d1
+
+ d2 = limiter.queue(key)
+ cm2 = yield d2
+
+ d3 = limiter.queue(key)
+ cm3 = yield d3
+
+ d4 = limiter.queue(key)
+ self.assertFalse(d4.called)
+
+ d5 = limiter.queue(key)
+ self.assertFalse(d5.called)
+
+ with cm1:
+ self.assertFalse(d4.called)
+ self.assertFalse(d5.called)
+
+ cm4 = yield d4
+ self.assertFalse(d5.called)
+
+ with cm3:
+ self.assertFalse(d5.called)
+
+ cm5 = yield d5
+
+ with cm2:
+ pass
+
+ with cm4:
+ pass
+
+ with cm5:
+ pass
+
+ d6 = limiter.queue(key)
+ with (yield d6):
+ pass
+
+ @defer.inlineCallbacks
+ def test_cancellation(self):
+ linearizer = Linearizer()
+
+ key = object()
+
+ d1 = linearizer.queue(key)
+ cm1 = yield d1
+
+ d2 = linearizer.queue(key)
+ self.assertFalse(d2.called)
+
+ d3 = linearizer.queue(key)
+ self.assertFalse(d3.called)
+
+ d2.cancel()
+
+ with cm1:
+ pass
+
+ self.assertTrue(d2.called)
+ try:
+ yield d2
+ self.fail("Expected d2 to raise CancelledError")
+ except CancelledError:
+ pass
+
+ with (yield d3):
+ pass
diff --git a/tests/utils.py b/tests/utils.py
index e488238bb3..c3dbff8507 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -71,6 +71,8 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
config.user_directory_search_all_users = False
config.user_consent_server_notice_content = None
config.block_events_without_consent_error = None
+ config.media_storage_providers = []
+ config.auto_join_rooms = []
# disable user directory updates, because they get done in the
# background, which upsets the test runner.
@@ -136,6 +138,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None
database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
+ reactor=reactor,
**kargs
)
|