diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 149e443022..cc1c862ba4 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,10 +19,10 @@ import signedjson.sign
from mock import Mock
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
-from synapse.util import async, logcontext
+from synapse.util import logcontext, Clock
from synapse.util.logcontext import LoggingContext
from tests import unittest, utils
-from twisted.internet import defer
+from twisted.internet import defer, reactor
class MockPerspectiveServer(object):
@@ -118,6 +118,7 @@ class KeyringTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_verify_json_objects_for_server_awaits_previous_requests(self):
+ clock = Clock(reactor)
key1 = signedjson.key.generate_signing_key(1)
kr = keyring.Keyring(self.hs)
@@ -167,7 +168,7 @@ class KeyringTestCase(unittest.TestCase):
# wait a tick for it to send the request to the perspectives server
# (it first tries the datastore)
- yield async.sleep(1) # XXX find out why this takes so long!
+ yield clock.sleep(1) # XXX find out why this takes so long!
self.http_client.post_json.assert_called_once()
self.assertIs(LoggingContext.current_context(), context_11)
@@ -183,7 +184,7 @@ class KeyringTestCase(unittest.TestCase):
res_deferreds_2 = kr.verify_json_objects_for_server(
[("server10", json1)],
)
- yield async.sleep(1)
+ yield clock.sleep(1)
self.http_client.post_json.assert_not_called()
res_deferreds_2[0].addBoth(self.check_context, None)
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index b5bc2fa255..6a757289db 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -1,9 +1,9 @@
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.rest.client.transactions import CLEANUP_PERIOD_MS
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from mock import Mock, call
-from synapse.util import async
+from synapse.util import Clock
from synapse.util.logcontext import LoggingContext
from tests import unittest
from tests.utils import MockClock
@@ -46,7 +46,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test_logcontexts_with_async_result(self):
@defer.inlineCallbacks
def cb():
- yield async.sleep(0)
+ yield Clock(reactor).sleep(0)
defer.returnValue("yay")
@defer.inlineCallbacks
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index eef38b6781..c5e2f5549a 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -14,7 +14,7 @@
# limitations under the License.
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.media_storage import MediaStorage
@@ -38,6 +38,7 @@ class MediaStorageTests(unittest.TestCase):
self.secondary_base_path = os.path.join(self.test_dir, "secondary")
hs = Mock()
+ hs.get_reactor = Mock(return_value=reactor)
hs.config.media_store_path = self.primary_base_path
storage_providers = [FileStorageProviderBackend(
@@ -46,7 +47,7 @@ class MediaStorageTests(unittest.TestCase):
self.filepaths = MediaFilePaths(self.primary_base_path)
self.media_storage = MediaStorage(
- self.primary_base_path, self.filepaths, storage_providers,
+ hs, self.primary_base_path, self.filepaths, storage_providers,
)
def tearDown(self):
diff --git a/tests/test_distributor.py b/tests/test_distributor.py
index 010aeaee7e..c066381698 100644
--- a/tests/test_distributor.py
+++ b/tests/test_distributor.py
@@ -19,7 +19,6 @@ from twisted.internet import defer
from mock import Mock, patch
from synapse.util.distributor import Distributor
-from synapse.util.async import run_on_reactor
class DistributorTestCase(unittest.TestCase):
@@ -95,7 +94,6 @@ class DistributorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def observer():
- yield run_on_reactor()
raise MyException("Oopsie")
self.dist.observe("whail", observer)
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
new file mode 100644
index 0000000000..d08e19c53a
--- /dev/null
+++ b/tests/test_event_auth.py
@@ -0,0 +1,151 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse import event_auth
+from synapse.api.errors import AuthError
+from synapse.events import FrozenEvent
+import unittest
+
+
+class EventAuthTestCase(unittest.TestCase):
+ def test_random_users_cannot_send_state_before_first_pl(self):
+ """
+ Check that, before the first PL lands, the creator is the only user
+ that can send a state event.
+ """
+ creator = "@creator:example.com"
+ joiner = "@joiner:example.com"
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.member", joiner): _join_event(joiner),
+ }
+
+ # creator should be able to send state
+ event_auth.check(
+ _random_state_event(creator), auth_events,
+ do_sig_check=False,
+ )
+
+ # joiner should not be able to send state
+ self.assertRaises(
+ AuthError,
+ event_auth.check,
+ _random_state_event(joiner),
+ auth_events,
+ do_sig_check=False,
+ ),
+
+ def test_state_default_level(self):
+ """
+ Check that users above the state_default level can send state and
+ those below cannot
+ """
+ creator = "@creator:example.com"
+ pleb = "@joiner:example.com"
+ king = "@joiner2:example.com"
+
+ auth_events = {
+ ("m.room.create", ""): _create_event(creator),
+ ("m.room.member", creator): _join_event(creator),
+ ("m.room.power_levels", ""): _power_levels_event(creator, {
+ "state_default": "30",
+ "users": {
+ pleb: "29",
+ king: "30",
+ },
+ }),
+ ("m.room.member", pleb): _join_event(pleb),
+ ("m.room.member", king): _join_event(king),
+ }
+
+ # pleb should not be able to send state
+ self.assertRaises(
+ AuthError,
+ event_auth.check,
+ _random_state_event(pleb),
+ auth_events,
+ do_sig_check=False,
+ ),
+
+ # king should be able to send state
+ event_auth.check(
+ _random_state_event(king), auth_events,
+ do_sig_check=False,
+ )
+
+
+# helpers for making events
+
+TEST_ROOM_ID = "!test:room"
+
+
+def _create_event(user_id):
+ return FrozenEvent({
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "m.room.create",
+ "sender": user_id,
+ "content": {
+ "creator": user_id,
+ },
+ })
+
+
+def _join_event(user_id):
+ return FrozenEvent({
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "m.room.member",
+ "sender": user_id,
+ "state_key": user_id,
+ "content": {
+ "membership": "join",
+ },
+ })
+
+
+def _power_levels_event(sender, content):
+ return FrozenEvent({
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "m.room.power_levels",
+ "sender": sender,
+ "state_key": "",
+ "content": content,
+ })
+
+
+def _random_state_event(sender):
+ return FrozenEvent({
+ "room_id": TEST_ROOM_ID,
+ "event_id": _get_event_id(),
+ "type": "test.state",
+ "sender": sender,
+ "state_key": "",
+ "content": {
+ "membership": "join",
+ },
+ })
+
+
+event_count = 0
+
+
+def _get_event_id():
+ global event_count
+ c = event_count
+ event_count += 1
+ return "!%i:example.com" % (c, )
diff --git a/tests/test_state.py b/tests/test_state.py
index a5c5e55951..71c412faf4 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -606,6 +606,14 @@ class StateTestCase(unittest.TestCase):
}
)
+ power_levels = create_event(
+ type=EventTypes.PowerLevels, state_key="",
+ content={"users": {
+ "@foo:bar": "100",
+ "@user_id:example.com": "100",
+ }}
+ )
+
creation = create_event(
type=EventTypes.Create, state_key="",
content={"creator": "@foo:bar"}
@@ -613,12 +621,14 @@ class StateTestCase(unittest.TestCase):
old_state_1 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=1),
]
old_state_2 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=2),
]
@@ -633,7 +643,7 @@ class StateTestCase(unittest.TestCase):
)
self.assertEqual(
- old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
+ old_state_2[3].event_id, context.current_state_ids[("test1", "1")]
)
# Reverse the depth to make sure we are actually using the depths
@@ -641,12 +651,14 @@ class StateTestCase(unittest.TestCase):
old_state_1 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=2),
]
old_state_2 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=1),
]
@@ -659,7 +671,7 @@ class StateTestCase(unittest.TestCase):
)
self.assertEqual(
- old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
+ old_state_1[3].event_id, context.current_state_ids[("test1", "1")]
)
def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 2516fe40f4..24754591df 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -18,7 +18,6 @@ import logging
import mock
from synapse.api.errors import SynapseError
-from synapse.util import async
from synapse.util import logcontext
from twisted.internet import defer
from synapse.util.caches import descriptors
@@ -195,7 +194,6 @@ class DescriptorTestCase(unittest.TestCase):
def fn(self, arg1):
@defer.inlineCallbacks
def inner_fn():
- yield async.run_on_reactor()
raise SynapseError(400, "blah")
return inner_fn()
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index bc92f85fa6..543ac5bed9 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -32,7 +32,7 @@ class DictCacheTestCase(unittest.TestCase):
seq = self.cache.sequence
test_value = {"test": "test_simple_cache_hit_full"}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key)
self.assertEqual(test_value, c.value)
@@ -44,7 +44,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = {
"test": "test_simple_cache_hit_partial"
}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test"])
self.assertEqual(test_value, c.value)
@@ -56,7 +56,7 @@ class DictCacheTestCase(unittest.TestCase):
test_value = {
"test": "test_simple_cache_miss_partial"
}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"])
self.assertEqual({}, c.value)
@@ -70,7 +70,7 @@ class DictCacheTestCase(unittest.TestCase):
"test2": "test_simple_cache_hit_miss_partial2",
"test3": "test_simple_cache_hit_miss_partial3",
}
- self.cache.update(seq, key, test_value, full=True)
+ self.cache.update(seq, key, test_value)
c = self.cache.get(key, ["test2"])
self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value)
@@ -82,13 +82,13 @@ class DictCacheTestCase(unittest.TestCase):
test_value_1 = {
"test": "test_simple_cache_hit_miss_partial",
}
- self.cache.update(seq, key, test_value_1, full=False)
+ self.cache.update(seq, key, test_value_1, fetched_keys=set("test"))
seq = self.cache.sequence
test_value_2 = {
"test2": "test_simple_cache_hit_miss_partial2",
}
- self.cache.update(seq, key, test_value_2, full=False)
+ self.cache.update(seq, key, test_value_2, fetched_keys=set("test2"))
c = self.cache.get(key)
self.assertEqual(
diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py
index d6e1082779..c2aae8f54c 100644
--- a/tests/util/test_file_consumer.py
+++ b/tests/util/test_file_consumer.py
@@ -30,7 +30,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_pull_consumer(self):
string_file = StringIO()
- consumer = BackgroundFileConsumer(string_file)
+ consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = DummyPullProducer()
@@ -54,7 +54,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_push_consumer(self):
string_file = BlockingStringWrite()
- consumer = BackgroundFileConsumer(string_file)
+ consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = NonCallableMock(spec_set=[])
@@ -80,7 +80,7 @@ class FileConsumerTests(unittest.TestCase):
@defer.inlineCallbacks
def test_push_producer_feedback(self):
string_file = BlockingStringWrite()
- consumer = BackgroundFileConsumer(string_file)
+ consumer = BackgroundFileConsumer(string_file, reactor=reactor)
try:
producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"])
diff --git a/tests/util/test_linearizer.py b/tests/util/test_linearizer.py
index 4865eb4bc6..bf7e3aa885 100644
--- a/tests/util/test_linearizer.py
+++ b/tests/util/test_linearizer.py
@@ -12,10 +12,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.util import async, logcontext
+
+from synapse.util import logcontext, Clock
from tests import unittest
-from twisted.internet import defer
+from twisted.internet import defer, reactor
from synapse.util.async import Linearizer
from six.moves import range
@@ -53,7 +54,7 @@ class LinearizerTestCase(unittest.TestCase):
self.assertEqual(
logcontext.LoggingContext.current_context(), lc)
if sleep:
- yield async.sleep(0)
+ yield Clock(reactor).sleep(0)
self.assertEqual(
logcontext.LoggingContext.current_context(), lc)
diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py
index ad78d884e0..9cf90fcfc4 100644
--- a/tests/util/test_logcontext.py
+++ b/tests/util/test_logcontext.py
@@ -3,8 +3,7 @@ from twisted.internet import defer
from twisted.internet import reactor
from .. import unittest
-from synapse.util.async import sleep
-from synapse.util import logcontext
+from synapse.util import logcontext, Clock
from synapse.util.logcontext import LoggingContext
@@ -22,18 +21,20 @@ class LoggingContextTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_sleep(self):
+ clock = Clock(reactor)
+
@defer.inlineCallbacks
def competing_callback():
with LoggingContext() as competing_context:
competing_context.request = "competing"
- yield sleep(0)
+ yield clock.sleep(0)
self._check_test_key("competing")
reactor.callLater(0, competing_callback)
with LoggingContext() as context_one:
context_one.request = "one"
- yield sleep(0)
+ yield clock.sleep(0)
self._check_test_key("one")
def _test_run_in_background(self, function):
@@ -87,7 +88,7 @@ class LoggingContextTestCase(unittest.TestCase):
def test_run_in_background_with_blocking_fn(self):
@defer.inlineCallbacks
def blocking_function():
- yield sleep(0)
+ yield Clock(reactor).sleep(0)
return self._test_run_in_background(blocking_function)
diff --git a/tests/utils.py b/tests/utils.py
index 262c4a5714..189fd2711c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -37,11 +37,15 @@ USE_POSTGRES_FOR_TESTS = False
@defer.inlineCallbacks
-def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
+def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None,
+ **kargs):
"""Setup a homeserver suitable for running tests against. Keyword arguments
are passed to the Homeserver constructor. If no datastore is supplied a
datastore backed by an in-memory sqlite db will be given to the HS.
"""
+ if reactor is None:
+ from twisted.internet import reactor
+
if config is None:
config = Mock()
config.signing_key = [MockKey()]
@@ -110,6 +114,7 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
database_engine=db_engine,
room_list_handler=object(),
tls_server_context_factory=Mock(),
+ reactor=reactor,
**kargs
)
db_conn = hs.get_db_conn()
|