summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/auth.py13
-rw-r--r--synapse/handlers/sync.py7
-rw-r--r--tests/handlers/test_sync.py40
3 files changed, 48 insertions, 12 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 9c62ec4374..170039fc82 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -775,17 +775,26 @@ class Auth(object):
             )
 
     @defer.inlineCallbacks
-    def check_auth_blocking(self):
+    def check_auth_blocking(self, user_id=None):
         """Checks if the user should be rejected for some external reason,
         such as monthly active user limiting or global disable flag
+
+        Args:
+            user_id(str): If present, checks for presence against existing MAU cohort
         """
         if self.hs.config.hs_disabled:
             raise AuthError(
                 403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED
             )
         if self.hs.config.limit_usage_by_mau is True:
+            # If the user is already part of the MAU cohort
+            if user_id:
+                timestamp = yield self.store._user_last_seen_monthly_active(user_id)
+                if timestamp:
+                    return
+            # Else if there is no room in the MAU bucket, bail
             current_mau = yield self.store.get_monthly_active_count()
             if current_mau >= self.hs.config.max_mau_value:
                 raise AuthError(
                     403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
-                )
+            )
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 776ddca638..d3b26a4106 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -208,7 +208,12 @@ class SyncHandler(object):
         Returns:
             Deferred[SyncResult]
         """
-        yield self.auth.check_auth_blocking()
+        # If the user is not part of the mau group, then check that limits have
+        # not been exceeded (if not part of the group by this point, almost certain
+        # auth_blocking will occur)
+        user_id = sync_config.user.to_string()
+        yield self.auth.check_auth_blocking(user_id)
+
         res = yield self.response_cache.wrap(
             sync_config.request_key,
             self._wait_for_sync_for_user,
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 497e4bd933..b95a8743a7 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -13,8 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from twisted.internet import defer
+from synapse.api.errors import AuthError, Codes
 
-from synapse.api.errors import AuthError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
 from synapse.handlers.sync import SyncConfig, SyncHandler
 from synapse.types import UserID
@@ -31,19 +31,41 @@ class SyncTestCase(tests.unittest.TestCase):
     def setUp(self):
         self.hs = yield setup_test_homeserver()
         self.sync_handler = SyncHandler(self.hs)
+        self.store = self.hs.get_datastore()
 
     @defer.inlineCallbacks
     def test_wait_for_sync_for_user_auth_blocking(self):
-        sync_config = SyncConfig(
-            user=UserID("@user", "server"),
+
+        user_id1 = "@user1:server"
+        user_id2 = "@user2:server"
+        sync_config = self._generate_sync_config(user_id1)
+
+        self.hs.config.limit_usage_by_mau = True
+        self.hs.config.max_mau_value = 1
+
+        # Check that the happy case does not throw errors
+        yield self.store.upsert_monthly_active_user(user_id1)
+        yield self.sync_handler.wait_for_sync_for_user(sync_config)
+
+        # Test that global lock works
+        self.hs.config.hs_disabled = True
+        with self.assertRaises(AuthError) as e:
+            yield self.sync_handler.wait_for_sync_for_user(sync_config)
+        self.assertEquals(e.exception.errcode, Codes.HS_DISABLED)
+
+        self.hs.config.hs_disabled = False
+
+        sync_config = self._generate_sync_config(user_id2)
+
+        with self.assertRaises(AuthError) as e:
+            yield self.sync_handler.wait_for_sync_for_user(sync_config)
+        self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED)
+
+    def _generate_sync_config(self, user_id):
+        return SyncConfig(
+            user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
             filter_collection=DEFAULT_FILTER_COLLECTION,
             is_guest=False,
             request_key="request_key",
             device_id="device_id",
         )
-        # Ensure that an exception is not thrown
-        yield self.sync_handler.wait_for_sync_for_user(sync_config)
-        self.hs.config.hs_disabled = True
-
-        with self.assertRaises(AuthError):
-            yield self.sync_handler.wait_for_sync_for_user(sync_config)