diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index bc75ddd3e9..dfcfaf79b6 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -19,6 +19,7 @@ import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.auth import Auth
+from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import UserTypes
from synapse.api.errors import (
AuthError,
@@ -49,7 +50,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
# AuthBlocking reads from the hs' config on initialization. We need to
# modify its config instead of the hs'
- self.auth_blocking = self.auth._auth_blocking
+ self.auth_blocking = AuthBlocking(hs)
self.test_user = "@foo:bar"
self.test_token = b"_test_token_"
@@ -312,9 +313,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertEqual(self.store.insert_client_ip.call_count, 2)
def test_get_user_from_macaroon(self):
- self.store.get_user_by_access_token = simple_async_mock(
- TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
- )
+ self.store.get_user_by_access_token = simple_async_mock(None)
user_id = "@baldrick:matrix.org"
macaroon = pymacaroons.Macaroon(
@@ -322,17 +321,14 @@ class AuthTestCase(unittest.HomeserverTestCase):
identifier="key",
key=self.hs.config.key.macaroon_secret_key,
)
+ # "Legacy" macaroons should not work for regular users not in the database
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("type = access")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
- user_info = self.get_success(
- self.auth.get_user_by_access_token(macaroon.serialize())
+ serialized = macaroon.serialize()
+ self.get_failure(
+ self.auth.get_user_by_access_token(serialized), InvalidClientTokenError
)
- self.assertEqual(user_id, user_info.user_id)
-
- # TODO: device_id should come from the macaroon, but currently comes
- # from the db.
- self.assertEqual(user_info.device_id, "device")
def test_get_guest_user_from_macaroon(self):
self.store.get_user_by_id = simple_async_mock({"is_guest": True})
@@ -362,20 +358,22 @@ class AuthTestCase(unittest.HomeserverTestCase):
small_number_of_users = 1
# Ensure no error thrown
- self.get_success(self.auth.check_auth_blocking())
+ self.get_success(self.auth_blocking.check_auth_blocking())
self.auth_blocking._limit_usage_by_mau = True
self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
- e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ e = self.get_failure(
+ self.auth_blocking.check_auth_blocking(), ResourceLimitError
+ )
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
# Ensure does not throw an error
self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
- self.get_success(self.auth.check_auth_blocking())
+ self.get_success(self.auth_blocking.check_auth_blocking())
def test_blocking_mau__depending_on_user_type(self):
self.auth_blocking._max_mau_value = 50
@@ -383,15 +381,18 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.store.get_monthly_active_count = simple_async_mock(100)
# Support users allowed
- self.get_success(self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT))
+ self.get_success(
+ self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
+ )
self.store.get_monthly_active_count = simple_async_mock(100)
# Bots not allowed
self.get_failure(
- self.auth.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError
+ self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
+ ResourceLimitError,
)
self.store.get_monthly_active_count = simple_async_mock(100)
# Real users not allowed
- self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
def test_blocking_mau__appservice_requester_allowed_when_not_tracking_ips(self):
self.auth_blocking._max_mau_value = 50
@@ -419,7 +420,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
app_service=appservice,
authenticated_entity="@appservice:server",
)
- self.get_success(self.auth.check_auth_blocking(requester=requester))
+ self.get_success(self.auth_blocking.check_auth_blocking(requester=requester))
def test_blocking_mau__appservice_requester_disallowed_when_tracking_ips(self):
self.auth_blocking._max_mau_value = 50
@@ -448,7 +449,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
authenticated_entity="@appservice:server",
)
self.get_failure(
- self.auth.check_auth_blocking(requester=requester), ResourceLimitError
+ self.auth_blocking.check_auth_blocking(requester=requester),
+ ResourceLimitError,
)
def test_reserved_threepid(self):
@@ -459,18 +461,21 @@ class AuthTestCase(unittest.HomeserverTestCase):
unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
self.auth_blocking._mau_limits_reserved_threepids = [threepid]
- self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
self.get_failure(
- self.auth.check_auth_blocking(threepid=unknown_threepid), ResourceLimitError
+ self.auth_blocking.check_auth_blocking(threepid=unknown_threepid),
+ ResourceLimitError,
)
- self.get_success(self.auth.check_auth_blocking(threepid=threepid))
+ self.get_success(self.auth_blocking.check_auth_blocking(threepid=threepid))
def test_hs_disabled(self):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
- e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ e = self.get_failure(
+ self.auth_blocking.check_auth_blocking(), ResourceLimitError
+ )
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
@@ -485,7 +490,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._hs_disabled = True
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
- e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError)
+ e = self.get_failure(
+ self.auth_blocking.check_auth_blocking(), ResourceLimitError
+ )
self.assertEqual(e.value.admin_contact, self.hs.config.server.admin_contact)
self.assertEqual(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.assertEqual(e.value.code, 403)
@@ -495,4 +502,4 @@ class AuthTestCase(unittest.HomeserverTestCase):
user = "@user:server"
self.auth_blocking._server_notices_mxid = user
self.auth_blocking._hs_disabled_message = "Reason for being disabled"
- self.get_success(self.auth.check_auth_blocking(user))
+ self.get_success(self.auth_blocking.check_auth_blocking(user))
|