summary refs log tree commit diff
path: root/tests/api/test_auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/api/test_auth.py')
-rw-r--r--tests/api/test_auth.py42
1 files changed, 27 insertions, 15 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index bc75ddd3e9..54af9089e9 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -19,6 +19,7 @@ import pymacaroons
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.auth import Auth
+from synapse.api.auth_blocking import AuthBlocking
 from synapse.api.constants import UserTypes
 from synapse.api.errors import (
     AuthError,
@@ -49,7 +50,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
-        self.auth_blocking = self.auth._auth_blocking
+        self.auth_blocking = AuthBlocking(hs)
 
         self.test_user = "@foo:bar"
         self.test_token = b"_test_token_"
@@ -362,20 +363,22 @@ class AuthTestCase(unittest.HomeserverTestCase):
         small_number_of_users = 1
 
         # Ensure no error thrown
-        self.get_success(self.auth.check_auth_blocking())
+        self.get_success(self.auth_blocking.check_auth_blocking())
 
         self.auth_blocking._limit_usage_by_mau = True
 
         self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
 
-        e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        e = self.get_failure(
+            self.auth_blocking.check_auth_blocking(), ResourceLimitError
+        )
         self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
         self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
         self.assertEqual(e.value.code, 403)
 
         # Ensure does not throw an error
         self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
-        self.get_success(self.auth.check_auth_blocking())
+        self.get_success(self.auth_blocking.check_auth_blocking())
 
     def test_blocking_mau__depending_on_user_type(self):
         self.auth_blocking._max_mau_value = 50
@@ -383,15 +386,18 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         self.store.get_monthly_active_count = simple_async_mock(100)
         # Support users allowed
-        self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
+        self.get_success(
+            self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
+        )
         self.store.get_monthly_active_count = simple_async_mock(100)
         # Bots not allowed
         self.get_failure(
-            self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
+            self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
+            ResourceLimitError,
         )
         self.store.get_monthly_active_count = simple_async_mock(100)
         # Real users not allowed
-        self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
 
     def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
         self.auth_blocking._max_mau_value = 50
@@ -419,7 +425,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             app_service=appservice,
             authenticated_entity="@appservice:server",
         )
-        self.get_success(self.auth.check_auth_blocking(requester=requester))
+        self.get_success(self.auth_blocking.check_auth_blocking(requester=requester))
 
     def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
         self.auth_blocking._max_mau_value = 50
@@ -448,7 +454,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
             authenticated_entity="@appservice:server",
         )
         self.get_failure(
-            self.auth.check_auth_blocking(requester=requester), ResourceLimitError
+            self.auth_blocking.check_auth_blocking(requester=requester),
+            ResourceLimitError,
         )
 
     def test_reserved_threepid(self):
@@ -459,18 +466,21 @@ class AuthTestCase(unittest.HomeserverTestCase):
         unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
         self.auth_blocking._mau_limits_reserved_threepids = [threepid]
 
-        self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
 
         self.get_failure(
-            self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
+            self.auth_blocking.check_auth_blocking(threepid=unknown_threepid),
+            ResourceLimitError,
         )
 
-        self.get_success(self.auth.check_auth_blocking(threepid=threepid))
+        self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))
 
     def test_hs_disabled(self):
         self.auth_blocking._hs_disabled = True
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
-        e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        e = self.get_failure(
+            self.auth_blocking.check_auth_blocking(), ResourceLimitError
+        )
         self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
         self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
         self.assertEqual(e.value.code, 403)
@@ -485,7 +495,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         self.auth_blocking._hs_disabled = True
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
-        e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+        e = self.get_failure(
+            self.auth_blocking.check_auth_blocking(), ResourceLimitError
+        )
         self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
         self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
         self.assertEqual(e.value.code, 403)
@@ -495,4 +507,4 @@ class AuthTestCase(unittest.HomeserverTestCase):
         user = "@user:server"
         self.auth_blocking._server_notices_mxid = user
         self.auth_blocking._hs_disabled_message = "Reason for being disabled"
-        self.get_success(self.auth.check_auth_blocking(user))
+        self.get_success(self.auth_blocking.check_auth_blocking(user))