summary refs log tree commit diff
diff options
context:
space:
mode:
authorNeil Johnson <neil@matrix.org>2018-08-02 16:57:35 +0100
committerNeil Johnson <neil@matrix.org>2018-08-02 16:57:35 +0100
commit74b1d46ad9ae692774f2e9d71cbbe1cea91b4070 (patch)
treed1790607bc3a8ae00db9cfa57efd79bf3fedaac1
parentremove unused count_monthly_users (diff)
downloadsynapse-74b1d46ad9ae692774f2e9d71cbbe1cea91b4070.tar.xz
do mau checks based on monthly_active_users table
Diffstat (limited to '')
-rw-r--r--synapse/api/auth.py13
-rw-r--r--synapse/handlers/auth.py10
-rw-r--r--synapse/handlers/register.py10
-rw-r--r--synapse/storage/client_ips.py15
-rw-r--r--tests/api/test_auth.py31
-rw-r--r--tests/handlers/test_auth.py8
-rw-r--r--tests/handlers/test_register.py71
7 files changed, 97 insertions, 61 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index d8022bcf8e..943a488339 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -773,3 +773,16 @@ class Auth(object):
             raise AuthError(
                 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
             )
+
+    @defer.inlineCallbacks
+    def check_auth_blocking(self, error):
+        """Checks if the user should be rejected for some external reason,
+        such as monthly active user limiting or global disable flag
+        Args:
+            error (Error): The error that should be raised if user is to be
+            blocked
+            """
+        if self.hs.config.limit_usage_by_mau is True:
+            current_mau = yield self.store.get_monthly_active_count()
+            if current_mau >= self.hs.config.max_mau_value:
+                raise error
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 184eef09d0..8f9cff92e8 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -913,12 +913,10 @@ class AuthHandler(BaseHandler):
         Ensure that if mau blocking is enabled that invalid users cannot
         log in.
         """
-        if self.hs.config.limit_usage_by_mau is True:
-            current_mau = yield self.store.count_monthly_users()
-            if current_mau >= self.hs.config.max_mau_value:
-                raise AuthError(
-                    403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
-                )
+        error = AuthError(
+            403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
+        )
+        yield self.auth.check_auth_blocking(error)
 
 
 @attr.s
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 289704b241..706ed8c292 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -540,9 +540,7 @@ class RegistrationHandler(BaseHandler):
         Do not accept registrations if monthly active user limits exceeded
          and limiting is enabled
         """
-        if self.hs.config.limit_usage_by_mau is True:
-            current_mau = yield self.store.count_monthly_users()
-            if current_mau >= self.hs.config.max_mau_value:
-                raise RegistrationError(
-                    403, "MAU Limit Exceeded", Codes.MAU_LIMIT_EXCEEDED
-                )
+        error = RegistrationError(
+            403, "Monthly Active User limits exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
+        )
+        yield self.auth.check_auth_blocking(error)
diff --git a/synapse/storage/client_ips.py b/synapse/storage/client_ips.py
index 506915a1ef..83d64d1563 100644
--- a/synapse/storage/client_ips.py
+++ b/synapse/storage/client_ips.py
@@ -97,21 +97,22 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
 
     @defer.inlineCallbacks
     def _populate_monthly_active_users(self, user_id):
+        """Checks on the state of monthly active user limits and optionally
+        add the user to the monthly active tables
+
+        Args:
+            user_id(str): the user_id to query
+        """
+
         store = self.hs.get_datastore()
-        print "entering _populate_monthly_active_users"
         if self.hs.config.limit_usage_by_mau:
-            print "self.hs.config.limit_usage_by_mau is TRUE"
             is_user_monthly_active = yield store.is_user_monthly_active(user_id)
-            print "is_user_monthly_active is %r" % is_user_monthly_active
             if is_user_monthly_active:
                 yield store.upsert_monthly_active_user(user_id)
             else:
                 count = yield store.get_monthly_active_count()
-                print "count is %d" % count
                 if count < self.hs.config.max_mau_value:
-                    print "count is less than self.hs.config.max_mau_value "
-                    res = yield store.upsert_monthly_active_user(user_id)
-                    print "upsert response is %r" % res
+                    yield store.upsert_monthly_active_user(user_id)
 
     def _update_client_ips_batch(self):
         def update():
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py
index a82d737e71..54bdf28663 100644
--- a/tests/api/test_auth.py
+++ b/tests/api/test_auth.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
 
 import synapse.handlers.auth
 from synapse.api.auth import Auth
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, Codes
 from synapse.types import UserID
 
 from tests import unittest
@@ -444,3 +444,32 @@ class AuthTestCase(unittest.TestCase):
         self.assertEqual("Guest access token used for regular user", cm.exception.msg)
 
         self.store.get_user_by_id.assert_called_with(USER_ID)
+
+    @defer.inlineCallbacks
+    def test_blocking_mau(self):
+        self.hs.config.limit_usage_by_mau = False
+        self.hs.config.max_mau_value = 50
+        lots_of_users = 100
+        small_number_of_users = 1
+
+        error = AuthError(
+            403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
+        )
+
+        # Ensure no error thrown
+        yield self.auth.check_auth_blocking(error)
+
+        self.hs.config.limit_usage_by_mau = True
+
+        self.store.get_monthly_active_count = Mock(
+            return_value=defer.succeed(lots_of_users)
+        )
+
+        with self.assertRaises(AuthError):
+            yield self.auth.check_auth_blocking(error)
+
+        # Ensure does not throw an error
+        self.store.get_monthly_active_count = Mock(
+            return_value=defer.succeed(small_number_of_users)
+        )
+        yield self.auth.check_auth_blocking(error)
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 55eab9e9cf..8a9bf2d5fd 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -132,14 +132,14 @@ class AuthTestCase(unittest.TestCase):
     @defer.inlineCallbacks
     def test_mau_limits_exceeded(self):
         self.hs.config.limit_usage_by_mau = True
-        self.hs.get_datastore().count_monthly_users = Mock(
+        self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.large_number_of_users)
         )
 
         with self.assertRaises(AuthError):
             yield self.auth_handler.get_access_token_for_user_id('user_a')
 
-        self.hs.get_datastore().count_monthly_users = Mock(
+        self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.large_number_of_users)
         )
         with self.assertRaises(AuthError):
@@ -151,13 +151,13 @@ class AuthTestCase(unittest.TestCase):
     def test_mau_limits_not_exceeded(self):
         self.hs.config.limit_usage_by_mau = True
 
-        self.hs.get_datastore().count_monthly_users = Mock(
+        self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.small_number_of_users)
         )
         # Ensure does not raise exception
         yield self.auth_handler.get_access_token_for_user_id('user_a')
 
-        self.hs.get_datastore().count_monthly_users = Mock(
+        self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=defer.succeed(self.small_number_of_users)
         )
         yield self.auth_handler.validate_short_term_login_token_and_get_user_id(
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 0937d71cf6..6b5b8b3772 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -50,6 +50,10 @@ class RegistrationTestCase(unittest.TestCase):
         self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
         self.hs.handlers = RegistrationHandlers(self.hs)
         self.handler = self.hs.get_handlers().registration_handler
+        self.store = self.hs.get_datastore()
+        self.hs.config.max_mau_value = 50
+        self.lots_of_users = 100
+        self.small_number_of_users = 1
 
     @defer.inlineCallbacks
     def test_user_is_created_and_logged_in_if_doesnt_exist(self):
@@ -80,51 +84,44 @@ class RegistrationTestCase(unittest.TestCase):
         self.assertEquals(result_token, 'secret')
 
     @defer.inlineCallbacks
-    def test_cannot_register_when_mau_limits_exceeded(self):
-        local_part = "someone"
-        display_name = "someone"
-        requester = create_requester("@as:test")
-        store = self.hs.get_datastore()
+    def test_mau_limits_when_disabled(self):
         self.hs.config.limit_usage_by_mau = False
-        self.hs.config.max_mau_value = 50
-        lots_of_users = 100
-        small_number_users = 1
-
-        store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
-
         # Ensure does not throw exception
-        yield self.handler.get_or_create_user(requester, 'a', display_name)
+        yield self.handler.get_or_create_user("requester", 'a', "display_name")
 
+    @defer.inlineCallbacks
+    def test_get_or_create_user_mau_not_blocked(self):
         self.hs.config.limit_usage_by_mau = True
-
-        with self.assertRaises(RegistrationError):
-            yield self.handler.get_or_create_user(requester, 'b', display_name)
-
-        store.count_monthly_users = Mock(return_value=defer.succeed(small_number_users))
-
-        self._macaroon_mock_generator("another_secret")
-
+        self.store.count_monthly_users = Mock(
+            return_value=defer.succeed(self.small_number_of_users)
+        )
         # Ensure does not throw exception
-        yield self.handler.get_or_create_user("@neil:matrix.org", 'c', "Neil")
+        yield self.handler.get_or_create_user("@user:server", 'c', "User")
 
-        self._macaroon_mock_generator("another another secret")
-        store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
+    @defer.inlineCallbacks
+    def test_get_or_create_user_mau_blocked(self):
+        self.hs.config.limit_usage_by_mau = True
+        self.store.get_monthly_active_count = Mock(
+            return_value=defer.succeed(self.lots_of_users)
+        )
 
         with self.assertRaises(RegistrationError):
-            yield self.handler.register(localpart=local_part)
+            yield self.handler.get_or_create_user("requester", 'b', "display_name")
 
-        self._macaroon_mock_generator("another another secret")
-        store.count_monthly_users = Mock(return_value=defer.succeed(lots_of_users))
+    @defer.inlineCallbacks
+    def test_register_mau_blocked(self):
+        self.hs.config.limit_usage_by_mau = True
+        self.store.get_monthly_active_count = Mock(
+            return_value=defer.succeed(self.lots_of_users)
+        )
+        with self.assertRaises(RegistrationError):
+            yield self.handler.register(localpart="local_part")
 
+    @defer.inlineCallbacks
+    def test_register_saml2_mau_blocked(self):
+        self.hs.config.limit_usage_by_mau = True
+        self.store.get_monthly_active_count = Mock(
+            return_value=defer.succeed(self.lots_of_users)
+        )
         with self.assertRaises(RegistrationError):
-            yield self.handler.register_saml2(local_part)
-
-    def _macaroon_mock_generator(self, secret):
-        """
-        Reset macaroon generator in the case where the test creates multiple users
-        """
-        macaroon_generator = Mock(
-            generate_access_token=Mock(return_value=secret))
-        self.hs.get_macaroon_generator = Mock(return_value=macaroon_generator)
-        self.hs.handlers = RegistrationHandlers(self.hs)
-        self.handler = self.hs.get_handlers().registration_handler
+            yield self.handler.register_saml2(localpart="local_part")