diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 6121efcfa9..cc0b10e7f6 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -68,7 +68,7 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_user_bad_token(self):
@@ -105,7 +105,7 @@ class AuthTestCase(unittest.TestCase):
request.getClientIP.return_value = "127.0.0.1"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
@defer.inlineCallbacks
@@ -125,7 +125,7 @@ class AuthTestCase(unittest.TestCase):
request.getClientIP.return_value = "192.168.10.10"
request.args[b"access_token"] = [self.test_token]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(requester.user.to_string(), self.test_user)
def test_get_user_by_req_appservice_valid_token_bad_ip(self):
@@ -188,7 +188,7 @@ class AuthTestCase(unittest.TestCase):
request.args[b"access_token"] = [self.test_token]
request.args[b"user_id"] = [masquerading_user_id]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request)
+ requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
self.assertEquals(
requester.user.to_string(), masquerading_user_id.decode("utf8")
)
@@ -225,7 +225,9 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
- user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
+ user_info = yield defer.ensureDeferred(
+ self.auth.get_user_by_access_token(macaroon.serialize())
+ )
user = user_info["user"]
self.assertEqual(UserID.from_string(user_id), user)
@@ -250,7 +252,9 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("guest = true")
serialized = macaroon.serialize()
- user_info = yield self.auth.get_user_by_access_token(serialized)
+ user_info = yield defer.ensureDeferred(
+ self.auth.get_user_by_access_token(serialized)
+ )
user = user_info["user"]
is_guest = user_info["is_guest"]
self.assertEqual(UserID.from_string(user_id), user)
@@ -260,10 +264,13 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cannot_use_regular_token_as_guest(self):
USER_ID = "@percy:matrix.org"
- self.store.add_access_token_to_user = Mock()
+ self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
+ self.store.get_device = Mock(return_value=defer.succeed(None))
- token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id(
- USER_ID, "DEVICE", valid_until_ms=None
+ token = yield defer.ensureDeferred(
+ self.hs.handlers.auth_handler.get_access_token_for_user_id(
+ USER_ID, "DEVICE", valid_until_ms=None
+ )
)
self.store.add_access_token_to_user.assert_called_with(
USER_ID, token, "DEVICE", None
@@ -286,7 +293,9 @@ class AuthTestCase(unittest.TestCase):
request = Mock(args={})
request.args[b"access_token"] = [token.encode("ascii")]
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
- requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+ requester = yield defer.ensureDeferred(
+ self.auth.get_user_by_req(request, allow_guest=True)
+ )
self.assertEqual(UserID.from_string(USER_ID), requester.user)
self.assertFalse(requester.is_guest)
@@ -301,7 +310,9 @@ class AuthTestCase(unittest.TestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
with self.assertRaises(InvalidClientCredentialsError) as cm:
- yield self.auth.get_user_by_req(request, allow_guest=True)
+ yield defer.ensureDeferred(
+ self.auth.get_user_by_req(request, allow_guest=True)
+ )
self.assertEqual(401, cm.exception.code)
self.assertEqual("Guest access token used for regular user", cm.exception.msg)
@@ -316,7 +327,7 @@ class AuthTestCase(unittest.TestCase):
small_number_of_users = 1
# Ensure no error thrown
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.hs.config.limit_usage_by_mau = True
@@ -325,7 +336,7 @@ class AuthTestCase(unittest.TestCase):
)
with self.assertRaises(ResourceLimitError) as e:
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@@ -334,7 +345,7 @@ class AuthTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock(
return_value=defer.succeed(small_number_of_users)
)
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks
def test_blocking_mau__depending_on_user_type(self):
@@ -343,15 +354,19 @@ class AuthTestCase(unittest.TestCase):
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Support users allowed
- yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
+ yield defer.ensureDeferred(
+ self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
+ )
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Bots not allowed
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking(user_type=UserTypes.BOT)
+ yield defer.ensureDeferred(
+ self.auth.check_auth_blocking(user_type=UserTypes.BOT)
+ )
self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
# Real users not allowed
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
@defer.inlineCallbacks
def test_reserved_threepid(self):
@@ -362,21 +377,22 @@ class AuthTestCase(unittest.TestCase):
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.hs.config.mau_limits_reserved_threepids = [threepid]
- yield self.store.register_user(user_id="user1", password_hash=None)
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
with self.assertRaises(ResourceLimitError):
- yield self.auth.check_auth_blocking(threepid=unknown_threepid)
+ yield defer.ensureDeferred(
+ self.auth.check_auth_blocking(threepid=unknown_threepid)
+ )
- yield self.auth.check_auth_blocking(threepid=threepid)
+ yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))
@defer.inlineCallbacks
def test_hs_disabled(self):
self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@@ -393,7 +409,7 @@ class AuthTestCase(unittest.TestCase):
self.hs.config.hs_disabled = True
self.hs.config.hs_disabled_message = "Reason for being disabled"
with self.assertRaises(ResourceLimitError) as e:
- yield self.auth.check_auth_blocking()
+ yield defer.ensureDeferred(self.auth.check_auth_blocking())
self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEquals(e.exception.code, 403)
@@ -404,4 +420,4 @@ class AuthTestCase(unittest.TestCase):
user = "@user:server"
self.hs.config.server_notices_mxid = user
self.hs.config.hs_disabled_message = "Reason for being disabled"
- yield self.auth.check_auth_blocking(user)
+ yield defer.ensureDeferred(self.auth.check_auth_blocking(user))
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index b03103d96f..52c4ac8b11 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -82,16 +82,16 @@ class AuthTestCase(unittest.TestCase):
self.hs.clock.now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
- user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- token
+ user_id = yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
self.hs.clock.now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- token
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
@defer.inlineCallbacks
@@ -99,8 +99,10 @@ class AuthTestCase(unittest.TestCase):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
macaroon = pymacaroons.Macaroon.deserialize(token)
- user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
+ user_id = yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
)
self.assertEqual("a_user", user_id)
@@ -109,20 +111,26 @@ class AuthTestCase(unittest.TestCase):
macaroon.add_first_party_caveat("user_id = b_user")
with self.assertRaises(synapse.api.errors.AuthError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- macaroon.serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ macaroon.serialize()
+ )
)
@defer.inlineCallbacks
def test_mau_limits_disabled(self):
self.hs.config.limit_usage_by_mau = False
# Ensure does not throw exception
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
@defer.inlineCallbacks
@@ -133,16 +141,20 @@ class AuthTestCase(unittest.TestCase):
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.large_number_of_users)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
@defer.inlineCallbacks
@@ -154,16 +166,20 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
with self.assertRaises(ResourceLimitError):
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
# If in monthly active cohort
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
@@ -172,8 +188,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().user_last_seen_monthly_active = Mock(
return_value=defer.succeed(self.hs.get_clock().time_msec())
@@ -181,8 +199,10 @@ class AuthTestCase(unittest.TestCase):
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.hs.config.max_mau_value)
)
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
@defer.inlineCallbacks
@@ -193,15 +213,19 @@ class AuthTestCase(unittest.TestCase):
return_value=defer.succeed(self.small_number_of_users)
)
# Ensure does not raise exception
- yield self.auth_handler.get_access_token_for_user_id(
- "user_a", device_id=None, valid_until_ms=None
+ yield defer.ensureDeferred(
+ self.auth_handler.get_access_token_for_user_id(
+ "user_a", device_id=None, valid_until_ms=None
+ )
)
self.hs.get_datastore().get_monthly_active_count = Mock(
return_value=defer.succeed(self.small_number_of_users)
)
- yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
- self._get_macaroon().serialize()
+ yield defer.ensureDeferred(
+ self.auth_handler.validate_short_term_login_token_and_get_user_id(
+ self._get_macaroon().serialize()
+ )
)
def _get_macaroon(self):
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index e7b638dbfe..f1dc51d6c9 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -294,7 +294,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
create_profile_with_displayname=user.localpart,
)
else:
- yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
+ yield defer.ensureDeferred(
+ self.hs.get_auth_handler().delete_access_tokens_for_user(user_id)
+ )
yield self.store.add_access_token_to_user(
user_id=user_id, token=token, device_id=None, valid_until_ms=None
diff --git a/tests/utils.py b/tests/utils.py
index 968d109f77..2079e0143d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -332,10 +332,15 @@ def setup_test_homeserver(
# Need to let the HS build an auth handler and then mess with it
# because AuthHandler's constructor requires the HS, so we can't make one
# beforehand and pass it in to the HS's constructor (chicken / egg)
- hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest()
- hs.get_auth_handler().validate_hash = (
- lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h
- )
+ async def hash(p):
+ return hashlib.md5(p.encode("utf8")).hexdigest()
+
+ hs.get_auth_handler().hash = hash
+
+ async def validate_hash(p, h):
+ return hashlib.md5(p.encode("utf8")).hexdigest() == h
+
+ hs.get_auth_handler().validate_hash = validate_hash
fed = kargs.get("resource_for_federation", None)
if fed:
|