summary refs log tree commit diff
path: root/tests/api/test_auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/api/test_auth.py')
-rw-r--r--tests/api/test_auth.py64
1 files changed, 40 insertions, 24 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 6121efcfa9..cc0b10e7f6 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -68,7 +68,7 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
         self.assertEquals(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_user_bad_token(self):
@@ -105,7 +105,7 @@ class AuthTestCase(unittest.TestCase):
         request.getClientIP.return_value = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
         self.assertEquals(requester.user.to_string(), self.test_user)
 
     @defer.inlineCallbacks
@@ -125,7 +125,7 @@ class AuthTestCase(unittest.TestCase):
         request.getClientIP.return_value = "192.168.10.10"
         request.args[b"access_token"] = [self.test_token]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
         self.assertEquals(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_appservice_valid_token_bad_ip(self):
@@ -188,7 +188,7 @@ class AuthTestCase(unittest.TestCase):
         request.args[b"access_token"] = [self.test_token]
         request.args[b"user_id"] = [masquerading_user_id]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        requester = yield self.auth.get_user_by_req(request)
+        requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request))
         self.assertEquals(
             requester.user.to_string(), masquerading_user_id.decode("utf8")
         )
@@ -225,7 +225,9 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("type = access")
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
-        user_info = yield self.auth.get_user_by_access_token(macaroon.serialize())
+        user_info = yield defer.ensureDeferred(
+            self.auth.get_user_by_access_token(macaroon.serialize())
+        )
         user = user_info["user"]
         self.assertEqual(UserID.from_string(user_id), user)
 
@@ -250,7 +252,9 @@ class AuthTestCase(unittest.TestCase):
         macaroon.add_first_party_caveat("guest = true")
         serialized = macaroon.serialize()
 
-        user_info = yield self.auth.get_user_by_access_token(serialized)
+        user_info = yield defer.ensureDeferred(
+            self.auth.get_user_by_access_token(serialized)
+        )
         user = user_info["user"]
         is_guest = user_info["is_guest"]
         self.assertEqual(UserID.from_string(user_id), user)
@@ -260,10 +264,13 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_cannot_use_regular_token_as_guest(self):
         USER_ID = "@percy:matrix.org"
-        self.store.add_access_token_to_user = Mock()
+        self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None))
+        self.store.get_device = Mock(return_value=defer.succeed(None))
 
-        token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id(
-            USER_ID, "DEVICE", valid_until_ms=None
+        token = yield defer.ensureDeferred(
+            self.hs.handlers.auth_handler.get_access_token_for_user_id(
+                USER_ID, "DEVICE", valid_until_ms=None
+            )
         )
         self.store.add_access_token_to_user.assert_called_with(
             USER_ID, token, "DEVICE", None
@@ -286,7 +293,9 @@ class AuthTestCase(unittest.TestCase):
         request = Mock(args={})
         request.args[b"access_token"] = [token.encode("ascii")]
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+        requester = yield defer.ensureDeferred(
+            self.auth.get_user_by_req(request, allow_guest=True)
+        )
         self.assertEqual(UserID.from_string(USER_ID), requester.user)
         self.assertFalse(requester.is_guest)
 
@@ -301,7 +310,9 @@ class AuthTestCase(unittest.TestCase):
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
         with self.assertRaises(InvalidClientCredentialsError) as cm:
-            yield self.auth.get_user_by_req(request, allow_guest=True)
+            yield defer.ensureDeferred(
+                self.auth.get_user_by_req(request, allow_guest=True)
+            )
 
         self.assertEqual(401, cm.exception.code)
         self.assertEqual("Guest access token used for regular user", cm.exception.msg)
@@ -316,7 +327,7 @@ class AuthTestCase(unittest.TestCase):
         small_number_of_users = 1
 
         # Ensure no error thrown
-        yield self.auth.check_auth_blocking()
+        yield defer.ensureDeferred(self.auth.check_auth_blocking())
 
         self.hs.config.limit_usage_by_mau = True
 
@@ -325,7 +336,7 @@ class AuthTestCase(unittest.TestCase):
         )
 
         with self.assertRaises(ResourceLimitError) as e:
-            yield self.auth.check_auth_blocking()
+            yield defer.ensureDeferred(self.auth.check_auth_blocking())
         self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
         self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
         self.assertEquals(e.exception.code, 403)
@@ -334,7 +345,7 @@ class AuthTestCase(unittest.TestCase):
         self.store.get_monthly_active_count = Mock(
             return_value=defer.succeed(small_number_of_users)
         )
-        yield self.auth.check_auth_blocking()
+        yield defer.ensureDeferred(self.auth.check_auth_blocking())
 
     @defer.inlineCallbacks
     def test_blocking_mau__depending_on_user_type(self):
@@ -343,15 +354,19 @@ class AuthTestCase(unittest.TestCase):
 
         self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
         # Support users allowed
-        yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
+        yield defer.ensureDeferred(
+            self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT)
+        )
         self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
         # Bots not allowed
         with self.assertRaises(ResourceLimitError):
-            yield self.auth.check_auth_blocking(user_type=UserTypes.BOT)
+            yield defer.ensureDeferred(
+                self.auth.check_auth_blocking(user_type=UserTypes.BOT)
+            )
         self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100))
         # Real users not allowed
         with self.assertRaises(ResourceLimitError):
-            yield self.auth.check_auth_blocking()
+            yield defer.ensureDeferred(self.auth.check_auth_blocking())
 
     @defer.inlineCallbacks
     def test_reserved_threepid(self):
@@ -362,21 +377,22 @@ class AuthTestCase(unittest.TestCase):
         unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
         self.hs.config.mau_limits_reserved_threepids = [threepid]
 
-        yield self.store.register_user(user_id="user1", password_hash=None)
         with self.assertRaises(ResourceLimitError):
-            yield self.auth.check_auth_blocking()
+            yield defer.ensureDeferred(self.auth.check_auth_blocking())
 
         with self.assertRaises(ResourceLimitError):
-            yield self.auth.check_auth_blocking(threepid=unknown_threepid)
+            yield defer.ensureDeferred(
+                self.auth.check_auth_blocking(threepid=unknown_threepid)
+            )
 
-        yield self.auth.check_auth_blocking(threepid=threepid)
+        yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid))
 
     @defer.inlineCallbacks
     def test_hs_disabled(self):
         self.hs.config.hs_disabled = True
         self.hs.config.hs_disabled_message = "Reason for being disabled"
         with self.assertRaises(ResourceLimitError) as e:
-            yield self.auth.check_auth_blocking()
+            yield defer.ensureDeferred(self.auth.check_auth_blocking())
         self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
         self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
         self.assertEquals(e.exception.code, 403)
@@ -393,7 +409,7 @@ class AuthTestCase(unittest.TestCase):
         self.hs.config.hs_disabled = True
         self.hs.config.hs_disabled_message = "Reason for being disabled"
         with self.assertRaises(ResourceLimitError) as e:
-            yield self.auth.check_auth_blocking()
+            yield defer.ensureDeferred(self.auth.check_auth_blocking())
         self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
         self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
         self.assertEquals(e.exception.code, 403)
@@ -404,4 +420,4 @@ class AuthTestCase(unittest.TestCase):
         user = "@user:server"
         self.hs.config.server_notices_mxid = user
         self.hs.config.hs_disabled_message = "Reason for being disabled"
-        yield self.auth.check_auth_blocking(user)
+        yield defer.ensureDeferred(self.auth.check_auth_blocking(user))