summary refs log tree commit diff
path: root/tests/api/test_auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/api/test_auth.py')
-rw-r--r--tests/api/test_auth.py126
1 files changed, 89 insertions, 37 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index 1faeb92f1e..dad86572df 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
 
 import synapse.handlers.auth
 from synapse.api.auth import Auth
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, Codes, ResourceLimitError
 from synapse.types import UserID
 
 from tests import unittest
@@ -34,13 +34,12 @@ class TestHandlers(object):
 
 
 class AuthTestCase(unittest.TestCase):
-
     @defer.inlineCallbacks
     def setUp(self):
         self.state_handler = Mock()
         self.store = Mock()
 
-        self.hs = yield setup_test_homeserver(handlers=None)
+        self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
         self.hs.get_datastore = Mock(return_value=self.store)
         self.hs.handlers = TestHandlers(self.hs)
         self.auth = Auth(self.hs)
@@ -53,11 +52,7 @@ class AuthTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_user_by_req_user_valid_token(self):
-        user_info = {
-            "name": self.test_user,
-            "token_id": "ditto",
-            "device_id": "device",
-        }
+        user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
         self.store.get_user_by_access_token = Mock(return_value=user_info)
 
         request = Mock(args={})
@@ -76,10 +71,7 @@ class AuthTestCase(unittest.TestCase):
         self.failureResultOf(d, AuthError)
 
     def test_get_user_by_req_user_missing_token(self):
-        user_info = {
-            "name": self.test_user,
-            "token_id": "ditto",
-        }
+        user_info = {"name": self.test_user, "token_id": "ditto"}
         self.store.get_user_by_access_token = Mock(return_value=user_info)
 
         request = Mock(args={})
@@ -90,8 +82,7 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_user_by_req_appservice_valid_token(self):
         app_service = Mock(
-            token="foobar", url="a_url", sender=self.test_user,
-            ip_range_whitelist=None,
+            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         self.store.get_user_by_access_token = Mock(return_value=None)
@@ -106,8 +97,11 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_user_by_req_appservice_valid_token_good_ip(self):
         from netaddr import IPSet
+
         app_service = Mock(
-            token="foobar", url="a_url", sender=self.test_user,
+            token="foobar",
+            url="a_url",
+            sender=self.test_user,
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
@@ -122,8 +116,11 @@ class AuthTestCase(unittest.TestCase):
 
     def test_get_user_by_req_appservice_valid_token_bad_ip(self):
         from netaddr import IPSet
+
         app_service = Mock(
-            token="foobar", url="a_url", sender=self.test_user,
+            token="foobar",
+            url="a_url",
+            sender=self.test_user,
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
@@ -160,8 +157,7 @@ class AuthTestCase(unittest.TestCase):
     def test_get_user_by_req_appservice_valid_token_valid_user_id(self):
         masquerading_user_id = b"@doppelganger:matrix.org"
         app_service = Mock(
-            token="foobar", url="a_url", sender=self.test_user,
-            ip_range_whitelist=None,
+            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
         )
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
@@ -174,15 +170,13 @@ class AuthTestCase(unittest.TestCase):
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
         requester = yield self.auth.get_user_by_req(request)
         self.assertEquals(
-            requester.user.to_string(),
-            masquerading_user_id.decode('utf8')
+            requester.user.to_string(), masquerading_user_id.decode('utf8')
         )
 
     def test_get_user_by_req_appservice_valid_token_bad_user_id(self):
         masquerading_user_id = b"@doppelganger:matrix.org"
         app_service = Mock(
-            token="foobar", url="a_url", sender=self.test_user,
-            ip_range_whitelist=None,
+            token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
         )
         app_service.is_interested_in_user = Mock(return_value=False)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
@@ -199,17 +193,15 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_get_user_from_macaroon(self):
         self.store.get_user_by_access_token = Mock(
-            return_value={
-                "name": "@baldrick:matrix.org",
-                "device_id": "device",
-            }
+            return_value={"name": "@baldrick:matrix.org", "device_id": "device"}
         )
 
         user_id = "@baldrick:matrix.org"
         macaroon = pymacaroons.Macaroon(
             location=self.hs.config.server_name,
             identifier="key",
-            key=self.hs.config.macaroon_secret_key)
+            key=self.hs.config.macaroon_secret_key,
+        )
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("type = access")
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
@@ -223,16 +215,15 @@ class AuthTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_guest_user_from_macaroon(self):
-        self.store.get_user_by_id = Mock(return_value={
-            "is_guest": True,
-        })
+        self.store.get_user_by_id = Mock(return_value={"is_guest": True})
         self.store.get_user_by_access_token = Mock(return_value=None)
 
         user_id = "@baldrick:matrix.org"
         macaroon = pymacaroons.Macaroon(
             location=self.hs.config.server_name,
             identifier="key",
-            key=self.hs.config.macaroon_secret_key)
+            key=self.hs.config.macaroon_secret_key,
+        )
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("type = access")
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
@@ -254,9 +245,7 @@ class AuthTestCase(unittest.TestCase):
         token = yield self.hs.handlers.auth_handler.issue_access_token(
             USER_ID, "DEVICE"
         )
-        self.store.add_access_token_to_user.assert_called_with(
-            USER_ID, token, "DEVICE"
-        )
+        self.store.add_access_token_to_user.assert_called_with(USER_ID, token, "DEVICE")
 
         def get_user(tok):
             if token != tok:
@@ -267,10 +256,9 @@ class AuthTestCase(unittest.TestCase):
                 "token_id": 1234,
                 "device_id": "DEVICE",
             }
+
         self.store.get_user_by_access_token = get_user
-        self.store.get_user_by_id = Mock(return_value={
-            "is_guest": False,
-        })
+        self.store.get_user_by_id = Mock(return_value={"is_guest": False})
 
         # check the token works
         request = Mock(args={})
@@ -297,3 +285,67 @@ class AuthTestCase(unittest.TestCase):
         self.assertEqual("Guest access token used for regular user", cm.exception.msg)
 
         self.store.get_user_by_id.assert_called_with(USER_ID)
+
+    @defer.inlineCallbacks
+    def test_blocking_mau(self):
+        self.hs.config.limit_usage_by_mau = False
+        self.hs.config.max_mau_value = 50
+        lots_of_users = 100
+        small_number_of_users = 1
+
+        # Ensure no error thrown
+        yield self.auth.check_auth_blocking()
+
+        self.hs.config.limit_usage_by_mau = True
+
+        self.store.get_monthly_active_count = Mock(
+            return_value=defer.succeed(lots_of_users)
+        )
+
+        with self.assertRaises(ResourceLimitError) as e:
+            yield self.auth.check_auth_blocking()
+        self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
+        self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        self.assertEquals(e.exception.code, 403)
+
+        # Ensure does not throw an error
+        self.store.get_monthly_active_count = Mock(
+            return_value=defer.succeed(small_number_of_users)
+        )
+        yield self.auth.check_auth_blocking()
+
+    @defer.inlineCallbacks
+    def test_reserved_threepid(self):
+        self.hs.config.limit_usage_by_mau = True
+        self.hs.config.max_mau_value = 1
+        self.store.get_monthly_active_count = lambda: defer.succeed(2)
+        threepid = {'medium': 'email', 'address': 'reserved@server.com'}
+        unknown_threepid = {'medium': 'email', 'address': 'unreserved@server.com'}
+        self.hs.config.mau_limits_reserved_threepids = [threepid]
+
+        yield self.store.register(user_id='user1', token="123", password_hash=None)
+        with self.assertRaises(ResourceLimitError):
+            yield self.auth.check_auth_blocking()
+
+        with self.assertRaises(ResourceLimitError):
+            yield self.auth.check_auth_blocking(threepid=unknown_threepid)
+
+        yield self.auth.check_auth_blocking(threepid=threepid)
+
+    @defer.inlineCallbacks
+    def test_hs_disabled(self):
+        self.hs.config.hs_disabled = True
+        self.hs.config.hs_disabled_message = "Reason for being disabled"
+        with self.assertRaises(ResourceLimitError) as e:
+            yield self.auth.check_auth_blocking()
+        self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact)
+        self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
+        self.assertEquals(e.exception.code, 403)
+
+    @defer.inlineCallbacks
+    def test_server_notices_mxid_special_cased(self):
+        self.hs.config.hs_disabled = True
+        user = "@user:server"
+        self.hs.config.server_notices_mxid = user
+        self.hs.config.hs_disabled_message = "Reason for being disabled"
+        yield self.auth.check_auth_blocking(user)