diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 31f54bbd7d..758ee071a5 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,54 +12,53 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from twisted.internet import defer
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
-from synapse.handlers.sync import SyncConfig, SyncHandler
+from synapse.handlers.sync import SyncConfig
from synapse.types import UserID
import tests.unittest
import tests.utils
-from tests.utils import setup_test_homeserver
-class SyncTestCase(tests.unittest.TestCase):
+class SyncTestCase(tests.unittest.HomeserverTestCase):
""" Tests Sync Handler. """
- @defer.inlineCallbacks
- def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup)
- self.sync_handler = SyncHandler(self.hs)
+ def prepare(self, reactor, clock, hs):
+ self.hs = hs
+ self.sync_handler = self.hs.get_sync_handler()
self.store = self.hs.get_datastore()
- @defer.inlineCallbacks
def test_wait_for_sync_for_user_auth_blocking(self):
user_id1 = "@user1:server"
user_id2 = "@user2:server"
sync_config = self._generate_sync_config(user_id1)
+ self.reactor.advance(100) # So we get not 0 time
self.hs.config.limit_usage_by_mau = True
self.hs.config.max_mau_value = 1
# Check that the happy case does not throw errors
- yield self.store.upsert_monthly_active_user(user_id1)
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
+ self.get_success(self.store.upsert_monthly_active_user(user_id1))
+ self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
# Test that global lock works
self.hs.config.hs_disabled = True
- with self.assertRaises(ResourceLimitError) as e:
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ e = self.get_failure(
+ self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ )
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.hs.config.hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
- with self.assertRaises(ResourceLimitError) as e:
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
- self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+ e = self.get_failure(
+ self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ )
+ self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def _generate_sync_config(self, user_id):
return SyncConfig(
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index f6d8660285..92b8726093 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -163,7 +163,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
@@ -227,7 +229,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
@@ -279,7 +283,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -300,7 +306,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
@@ -317,7 +325,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])])
self.assertEquals(self.event_source.get_current_key(), 2)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=1)
+ )
self.assertEquals(
events[0],
[{"type": "m.typing", "room_id": ROOM_ID, "content": {"user_ids": []}}],
@@ -335,7 +345,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.reset_mock()
self.assertEquals(self.event_source.get_current_key(), 3)
- events = self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ events = self.get_success(
+ self.event_source.get_new_events(room_ids=[ROOM_ID], from_key=0)
+ )
self.assertEquals(
events[0],
[
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 30fb77bac8..4bc3aaf02d 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -109,7 +109,9 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code)
self.assertEquals(self.event_source.get_current_key(), 1)
- events = self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+ events = self.get_success(
+ self.event_source.get_new_events(from_key=0, room_ids=[self.room_id])
+ )
self.assertEquals(
events[0],
[
diff --git a/tests/unittest.py b/tests/unittest.py
index 68d245ec9f..b30b7d1718 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -18,6 +18,7 @@
import gc
import hashlib
import hmac
+import inspect
import logging
import time
@@ -25,7 +26,7 @@ from mock import Mock
from canonicaljson import json
-from twisted.internet.defer import Deferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
@@ -417,6 +418,8 @@ class HomeserverTestCase(TestCase):
self.reactor.pump([by] * 100)
def get_success(self, d, by=0.0):
+ if inspect.isawaitable(d):
+ d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump(by=by)
@@ -426,6 +429,8 @@ class HomeserverTestCase(TestCase):
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
+ if inspect.isawaitable(d):
+ d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump()
|