diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 0bfb86bf1f..5d45689c8c 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -62,12 +62,15 @@ class AuthTestCase(unittest.TestCase):
# this is overridden for the appservice tests
self.store.get_app_service_by_token = Mock(return_value=None)
+ self.store.insert_client_ip = Mock(return_value=defer.succeed(None))
self.store.is_support_user = Mock(return_value=defer.succeed(False))
@defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self):
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
- self.store.get_user_by_access_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(
+ return_value=defer.succeed(user_info)
+ )
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
@@ -76,23 +79,25 @@ class AuthTestCase(unittest.TestCase):
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"}
- self.store.get_user_by_access_token = Mock(return_value=user_info)
+ self.store.get_user_by_access_token = Mock(
+ return_value=defer.succeed(user_info)
+ )
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -103,7 +108,7 @@ class AuthTestCase(unittest.TestCase):
token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
@@ -123,7 +128,7 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "192.168.10.10"
@@ -142,25 +147,25 @@ class AuthTestCase(unittest.TestCase):
ip_range_whitelist=IPSet(["192.168/16"]),
)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "131.111.8.42"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_appservice_bad_token(self):
self.store.get_app_service_by_token = Mock(return_value=None)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, InvalidClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
@@ -168,11 +173,11 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_by_req_appservice_missing_token(self):
app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
f = self.failureResultOf(d, MissingClientTokenError).value
self.assertEqual(f.code, 401)
self.assertEqual(f.errcode, "M_MISSING_TOKEN")
@@ -185,7 +190,11 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=True)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ # This just needs to return a truth-y value.
+ self.store.get_user_by_id = Mock(
+ return_value=defer.succeed({"is_guest": False})
+ )
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
@@ -204,20 +213,22 @@ class AuthTestCase(unittest.TestCase):
)
app_service.is_interested_in_user = Mock(return_value=False)
self.store.get_app_service_by_token = Mock(return_value=app_service)
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
request = Mock(args={})
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- d = self.auth.get_user_by_req(request)
+ d = defer.ensureDeferred(self.auth.get_user_by_req(request))
self.failureResultOf(d, AuthError)
@defer.inlineCallbacks
def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock(
- return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
+ return_value=defer.succeed(
+ {"name": "@baldrick:matrix.org", "device_id": "device"}
+ )
)
user_id = "@baldrick:matrix.org"
@@ -241,8 +252,8 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self):
- self.store.get_user_by_id = Mock(return_value={"is_guest": True})
- self.store.get_user_by_access_token = Mock(return_value=None)
+ self.store.get_user_by_id = Mock(return_value=defer.succeed({"is_guest": True}))
+ self.store.get_user_by_access_token = Mock(return_value=defer.succeed(None))
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@@ -282,16 +293,20 @@ class AuthTestCase(unittest.TestCase):
def get_user(tok):
if token != tok:
- return None
- return {
- "name": USER_ID,
- "is_guest": False,
- "token_id": 1234,
- "device_id": "DEVICE",
- }
+ return defer.succeed(None)
+ return defer.succeed(
+ {
+ "name": USER_ID,
+ "is_guest": False,
+ "token_id": 1234,
+ "device_id": "DEVICE",
+ }
+ )
self.store.get_user_by_access_token = get_user
- self.store.get_user_by_id = Mock(return_value={"is_guest": False})
+ self.store.get_user_by_id = Mock(
+ return_value=defer.succeed({"is_guest": False})
+ )
# check the token works
request = Mock(args={})
diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py
index 4e67503cf0..1fab1d6b69 100644
--- a/tests/api/test_filtering.py
+++ b/tests/api/test_filtering.py
@@ -375,8 +375,10 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.profile")
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_presence(events=events)
@@ -396,8 +398,10 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart + "2", filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart + "2", filter_id=filter_id
+ )
)
results = user_filter.filter_presence(events=events)
@@ -412,8 +416,10 @@ class FilteringTestCase(unittest.TestCase):
event = MockEvent(sender="@foo:bar", type="m.room.topic", room_id="!foo:bar")
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_room_state(events=events)
@@ -430,8 +436,10 @@ class FilteringTestCase(unittest.TestCase):
)
events = [event]
- user_filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ user_filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
results = user_filter.filter_room_state(events)
@@ -465,8 +473,10 @@ class FilteringTestCase(unittest.TestCase):
self.assertEquals(
user_filter_json,
(
- yield self.datastore.get_user_filter(
- user_localpart=user_localpart, filter_id=0
+ yield defer.ensureDeferred(
+ self.datastore.get_user_filter(
+ user_localpart=user_localpart, filter_id=0
+ )
)
),
)
@@ -479,8 +489,10 @@ class FilteringTestCase(unittest.TestCase):
user_localpart=user_localpart, user_filter=user_filter_json
)
- filter = yield self.filtering.get_user_filter(
- user_localpart=user_localpart, filter_id=filter_id
+ filter = yield defer.ensureDeferred(
+ self.filtering.get_user_filter(
+ user_localpart=user_localpart, filter_id=filter_id
+ )
)
self.assertEquals(filter.get_filter_json(), user_filter_json)
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 628f7d8db0..2a0b7c1b56 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -120,7 +120,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self.mock_as_api.query_alias.return_value = make_awaitable(True)
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_association_from_room_alias.return_value = defer.succeed(
+ self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
Mock(room_id=room_id, servers=servers)
)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5878f74175..b7d0adb10e 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -126,10 +126,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- def check_user_in_room(room_id, user_id):
+ async def check_user_in_room(room_id, user_id):
if user_id not in [u.to_string() for u in self.room_members]:
raise AuthError(401, "User is not in the room")
- return defer.succeed(None)
+ return None
hs.get_auth().check_user_in_room = check_user_in_room
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index f16eef15f7..17d0aae2e9 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -20,6 +20,8 @@ import urllib.parse
from mock import Mock
+from twisted.internet import defer
+
import synapse.rest.admin
from synapse.api.constants import UserTypes
from synapse.api.errors import HttpResponseException, ResourceLimitError
@@ -335,7 +337,9 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
store = self.hs.get_datastore()
# Set monthly active users to the limit
- store.get_monthly_active_count = Mock(return_value=self.hs.config.max_mau_value)
+ store.get_monthly_active_count = Mock(
+ return_value=defer.succeed(self.hs.config.max_mau_value)
+ )
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
self.get_failure(
@@ -588,7 +592,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=self.hs.config.max_mau_value
+ return_value=defer.succeed(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
@@ -628,7 +632,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
# Set monthly active users to the limit
self.store.get_monthly_active_count = Mock(
- return_value=self.hs.config.max_mau_value
+ return_value=defer.succeed(self.hs.config.max_mau_value)
)
# Check that the blocking of monthly active users is working as expected
# The registration of a new user fails due to the limit
diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py
index 8df58b4a63..ace0a3c08d 100644
--- a/tests/rest/client/v1/test_profile.py
+++ b/tests/rest/client/v1/test_profile.py
@@ -70,8 +70,8 @@ class MockHandlerProfileTestCase(unittest.TestCase):
profile_handler=self.mock_handler,
)
- def _get_user_by_req(request=None, allow_guest=False):
- return defer.succeed(synapse.types.create_requester(myid))
+ async def _get_user_by_req(request=None, allow_guest=False):
+ return synapse.types.create_requester(myid)
hs.get_auth().get_user_by_req = _get_user_by_req
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 5ccda8b2bd..ef6b775ed2 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -23,8 +23,6 @@ from urllib import parse as urlparse
from mock import Mock
-from twisted.internet import defer
-
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
@@ -51,8 +49,8 @@ class RoomBase(unittest.HomeserverTestCase):
self.hs.get_federation_handler = Mock(return_value=Mock())
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
self.hs.get_datastore().insert_client_ip = _insert_client_ip
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 18260bb90e..94d2bf2eb1 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -46,7 +46,7 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_handlers().federation_handler = Mock()
- def get_user_by_access_token(token=None, allow_guest=False):
+ async def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.auth_user_id),
"token_id": 1,
@@ -55,8 +55,8 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
hs.get_auth().get_user_by_access_token = get_user_by_access_token
- def _insert_client_ip(*args, **kwargs):
- return defer.succeed(None)
+ async def _insert_client_ip(*args, **kwargs):
+ return None
hs.get_datastore().insert_client_ip = _insert_client_ip
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 51941f99f9..8933b560d2 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -165,26 +165,6 @@ class RestHelper(object):
return channel.json_body
- def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200):
- if txn_id is None:
- txn_id = "m%s" % (str(time.time()))
-
- path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id)
- if tok:
- path = path + "?access_token=%s" % tok
-
- request, channel = make_request(
- self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8")
- )
- render(request, self.resource, self.hs.get_reactor())
-
- assert int(channel.result["code"]) == expect_code, (
- "Expected: %d, got: %d, resp: %r"
- % (expect_code, int(channel.result["code"]), channel.result["body"])
- )
-
- return channel.json_body
-
def _read_write_state(
self,
room_id: str,
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 7deaf5b24a..53a43038f0 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -116,8 +116,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
self.assertEquals(channel.result["code"], b"200", channel.result)
self.assertDictContainsSubset(det_data, channel.json_body)
+ @override_config({"enable_registration": False})
def test_POST_disabled_registration(self):
- self.hs.config.enable_registration = False
request_data = json.dumps({"username": "kermit", "password": "monkey"})
self.auth_result = (None, {"username": "kermit", "password": "monkey"}, None)
diff --git a/tests/rest/client/v2_alpha/test_sync.py b/tests/rest/client/v2_alpha/test_sync.py
index a31e44c97e..fa3a3ec1bd 100644
--- a/tests/rest/client/v2_alpha/test_sync.py
+++ b/tests/rest/client/v2_alpha/test_sync.py
@@ -16,9 +16,9 @@
import json
import synapse.rest.admin
-from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
+from synapse.api.constants import EventContentFields, EventTypes
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import read_marker, sync
+from synapse.rest.client.v2_alpha import sync
from tests import unittest
from tests.server import TimedOutException
@@ -324,156 +324,3 @@ class SyncTypingTests(unittest.HomeserverTestCase):
"GET", sync_url % (access_token, next_batch)
)
self.assertRaises(TimedOutException, self.render, request)
-
-
-class UnreadMessagesTestCase(unittest.HomeserverTestCase):
- servlets = [
- synapse.rest.admin.register_servlets,
- login.register_servlets,
- read_marker.register_servlets,
- room.register_servlets,
- sync.register_servlets,
- ]
-
- def prepare(self, reactor, clock, hs):
- self.url = "/sync?since=%s"
- self.next_batch = "s0"
-
- # Register the first user (used to check the unread counts).
- self.user_id = self.register_user("kermit", "monkey")
- self.tok = self.login("kermit", "monkey")
-
- # Create the room we'll check unread counts for.
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
-
- # Register the second user (used to send events to the room).
- self.user2 = self.register_user("kermit2", "monkey")
- self.tok2 = self.login("kermit2", "monkey")
-
- # Change the power levels of the room so that the second user can send state
- # events.
- self.helper.send_state(
- self.room_id,
- EventTypes.PowerLevels,
- {
- "users": {self.user_id: 100, self.user2: 100},
- "users_default": 0,
- "events": {
- "m.room.name": 50,
- "m.room.power_levels": 100,
- "m.room.history_visibility": 100,
- "m.room.canonical_alias": 50,
- "m.room.avatar": 50,
- "m.room.tombstone": 100,
- "m.room.server_acl": 100,
- "m.room.encryption": 100,
- },
- "events_default": 0,
- "state_default": 50,
- "ban": 50,
- "kick": 50,
- "redact": 50,
- "invite": 0,
- },
- tok=self.tok,
- )
-
- def test_unread_counts(self):
- """Tests that /sync returns the right value for the unread count (MSC2654)."""
-
- # Check that our own messages don't increase the unread count.
- self.helper.send(self.room_id, "hello", tok=self.tok)
- self._check_unread_count(0)
-
- # Join the new user and check that this doesn't increase the unread count.
- self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
- self._check_unread_count(0)
-
- # Check that the new user sending a message increases our unread count.
- res = self.helper.send(self.room_id, "hello", tok=self.tok2)
- self._check_unread_count(1)
-
- # Send a read receipt to tell the server we've read the latest event.
- body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
- request, channel = self.make_request(
- "POST",
- "/rooms/%s/read_markers" % self.room_id,
- body,
- access_token=self.tok,
- )
- self.render(request)
- self.assertEqual(channel.code, 200, channel.json_body)
-
- # Check that the unread counter is back to 0.
- self._check_unread_count(0)
-
- # Check that room name changes increase the unread counter.
- self.helper.send_state(
- self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
- )
- self._check_unread_count(1)
-
- # Check that room topic changes increase the unread counter.
- self.helper.send_state(
- self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
- )
- self._check_unread_count(2)
-
- # Check that encrypted messages increase the unread counter.
- self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
- self._check_unread_count(3)
-
- # Check that custom events with a body increase the unread counter.
- self.helper.send_event(
- self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
- )
- self._check_unread_count(4)
-
- # Check that edits don't increase the unread counter.
- self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={
- "body": "hello",
- "msgtype": "m.text",
- "m.relates_to": {"rel_type": RelationTypes.REPLACE},
- },
- tok=self.tok2,
- )
- self._check_unread_count(4)
-
- # Check that notices don't increase the unread counter.
- self.helper.send_event(
- room_id=self.room_id,
- type=EventTypes.Message,
- content={"body": "hello", "msgtype": "m.notice"},
- tok=self.tok2,
- )
- self._check_unread_count(4)
-
- # Check that tombstone events changes increase the unread counter.
- self.helper.send_state(
- self.room_id,
- EventTypes.Tombstone,
- {"replacement_room": "!someroom:test"},
- tok=self.tok2,
- )
- self._check_unread_count(5)
-
- def _check_unread_count(self, expected_count: True):
- """Syncs and compares the unread count with the expected value."""
-
- request, channel = self.make_request(
- "GET", self.url % self.next_batch, access_token=self.tok,
- )
- self.render(request)
-
- self.assertEqual(channel.code, 200, channel.json_body)
-
- room_entry = channel.json_body["rooms"]["join"][self.room_id]
- self.assertEqual(
- room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
- )
-
- # Store the next batch for the next request.
- self.next_batch = channel.json_body["next_batch"]
diff --git a/tests/rest/test_health.py b/tests/rest/test_health.py
new file mode 100644
index 0000000000..2d021f6565
--- /dev/null
+++ b/tests/rest/test_health.py
@@ -0,0 +1,34 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from synapse.rest.health import HealthResource
+
+from tests import unittest
+
+
+class HealthCheckTests(unittest.HomeserverTestCase):
+ def setUp(self):
+ super().setUp()
+
+ # replace the JsonResource with a HealthResource.
+ self.resource = HealthResource()
+
+ def test_health(self):
+ request, channel = self.make_request("GET", "/health", shorthand=False)
+ self.render(request)
+
+ self.assertEqual(request.code, 200)
+ self.assertEqual(channel.result["body"], b"OK")
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 7f70353b0d..3f88abe3d2 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -258,7 +258,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase):
self.user_id = "@user_id:test"
def test_server_notice_only_sent_once(self):
- self.store.get_monthly_active_count = Mock(return_value=1000)
+ self.store.get_monthly_active_count = Mock(return_value=defer.succeed(1000))
self.store.user_last_seen_monthly_active = Mock(
return_value=defer.succeed(1000)
diff --git a/tests/storage/test_directory.py b/tests/storage/test_directory.py
index 4e128e1047..daac947cb2 100644
--- a/tests/storage/test_directory.py
+++ b/tests/storage/test_directory.py
@@ -34,8 +34,10 @@ class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_room_to_alias(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
self.assertEquals(
@@ -45,24 +47,36 @@ class DirectoryStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_alias_to_room(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
self.assertObjectHasAttributes(
{"room_id": self.room.to_string(), "servers": ["test"]},
- (yield self.store.get_association_from_room_alias(self.alias)),
+ (
+ yield defer.ensureDeferred(
+ self.store.get_association_from_room_alias(self.alias)
+ )
+ ),
)
@defer.inlineCallbacks
def test_delete_alias(self):
- yield self.store.create_room_alias_association(
- room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ yield defer.ensureDeferred(
+ self.store.create_room_alias_association(
+ room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
+ )
)
- room_id = yield self.store.delete_room_alias(self.alias)
+ room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
self.assertEqual(self.room.to_string(), room_id)
self.assertIsNone(
- (yield self.store.get_association_from_room_alias(self.alias))
+ (
+ yield defer.ensureDeferred(
+ self.store.get_association_from_room_alias(self.alias)
+ )
+ )
)
diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py
index 398d546280..9f8d30373b 100644
--- a/tests/storage/test_end_to_end_keys.py
+++ b/tests/storage/test_end_to_end_keys.py
@@ -34,7 +34,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield self.store.set_e2e_device_keys("user", "device", now, json)
- res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys((("user", "device"),))
+ )
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
@@ -63,7 +65,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield self.store.set_e2e_device_keys("user", "device", now, json)
yield self.store.store_device("user", "device", "display_name")
- res = yield self.store.get_e2e_device_keys((("user", "device"),))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys((("user", "device"),))
+ )
self.assertIn("user", res)
self.assertIn("device", res["user"])
dev = res["user"]["device"]
@@ -85,8 +89,8 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
- res = yield self.store.get_e2e_device_keys(
- (("user1", "device1"), ("user2", "device2"))
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
)
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 259f2215f1..e793781a26 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from synapse.api.constants import UserTypes
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.unittest import default_config, override_config
FORTY_DAYS = 40 * 24 * 60 * 60
@@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
d = self.store.populate_monthly_active_users(user_id)
self.get_success(d)
@@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
@@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
self.store.user_last_seen_monthly_active = Mock(
@@ -333,7 +340,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self):
- self.store.upsert_monthly_active_user = Mock()
+ self.store.upsert_monthly_active_user = Mock(
+ side_effect=lambda user_id: make_awaitable(None)
+ )
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 6a545d2eb0..ecfafe68a9 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -40,7 +40,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
def test_search_user_dir(self):
# normally when alice searches the directory she should just find
# bob because bobby doesn't share a room with her.
- r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(1, len(r["results"]))
self.assertDictEqual(
@@ -51,7 +51,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
def test_search_user_dir_all_users(self):
self.hs.config.user_directory_search_all_users = True
try:
- r = yield self.store.search_user_dir(ALICE, "bob", 10)
+ r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
self.assertFalse(r["limited"])
self.assertEqual(2, len(r["results"]))
self.assertDictEqual(
diff --git a/tests/unittest.py b/tests/unittest.py
index 2152c693f2..d0bba3ddef 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -241,20 +241,16 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"):
if self.hijack_auth:
- def get_user_by_access_token(token=None, allow_guest=False):
- return succeed(
- {
- "user": UserID.from_string(self.helper.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- )
-
- def get_user_by_req(request, allow_guest=False, rights="access"):
- return succeed(
- create_requester(
- UserID.from_string(self.helper.auth_user_id), 1, False, None
- )
+ async def get_user_by_access_token(token=None, allow_guest=False):
+ return {
+ "user": UserID.from_string(self.helper.auth_user_id),
+ "token_id": 1,
+ "is_guest": False,
+ }
+
+ async def get_user_by_req(request, allow_guest=False, rights="access"):
+ return create_requester(
+ UserID.from_string(self.helper.auth_user_id), 1, False, None
)
self.hs.get_auth().get_user_by_req = get_user_by_req
diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py
index 9e348694ad..bc42ffce88 100644
--- a/tests/util/test_retryutils.py
+++ b/tests/util/test_retryutils.py
@@ -26,9 +26,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
def test_new_destination(self):
"""A happy-path case with a new destination and a successful operation"""
store = self.hs.get_datastore()
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
# advance the clock a bit before making the request
self.pump(1)
@@ -36,18 +34,14 @@ class RetryLimiterTestCase(HomeserverTestCase):
with limiter:
pass
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
def test_limiter(self):
"""General test case which walks through the process of a failing request"""
store = self.hs.get_datastore()
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
try:
@@ -58,29 +52,22 @@ class RetryLimiterTestCase(HomeserverTestCase):
except AssertionError:
pass
- # wait for the update to land
- self.pump()
-
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], failure_ts)
self.assertEqual(new_timings["retry_interval"], MIN_RETRY_INTERVAL)
# now if we try again we should get a failure
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- self.failureResultOf(d, NotRetryingDestination)
+ self.get_failure(
+ get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination
+ )
#
# advance the clock and try again
#
self.pump(MIN_RETRY_INTERVAL)
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
try:
@@ -91,12 +78,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
except AssertionError:
pass
- # wait for the update to land
- self.pump()
-
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertEqual(new_timings["failure_ts"], failure_ts)
self.assertEqual(new_timings["retry_last_ts"], retry_ts)
self.assertGreaterEqual(
@@ -110,9 +92,7 @@ class RetryLimiterTestCase(HomeserverTestCase):
# one more go, with success
#
self.pump(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0)
- d = get_retry_limiter("test_dest", self.clock, store)
- self.pump()
- limiter = self.successResultOf(d)
+ limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store))
self.pump(1)
with limiter:
@@ -121,7 +101,5 @@ class RetryLimiterTestCase(HomeserverTestCase):
# wait for the update to land
self.pump()
- d = store.get_destination_retry_timings("test_dest")
- self.pump()
- new_timings = self.successResultOf(d)
+ new_timings = self.get_success(store.get_destination_retry_timings("test_dest"))
self.assertIsNone(new_timings)
|