diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 9c62ec4374..170039fc82 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -775,17 +775,26 @@ class Auth(object):
)
@defer.inlineCallbacks
- def check_auth_blocking(self):
+ def check_auth_blocking(self, user_id=None):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
+
+ Args:
+ user_id(str): If present, checks for presence against existing MAU cohort
"""
if self.hs.config.hs_disabled:
raise AuthError(
403, self.hs.config.hs_disabled_message, errcode=Codes.HS_DISABLED
)
if self.hs.config.limit_usage_by_mau is True:
+ # If the user is already part of the MAU cohort
+ if user_id:
+ timestamp = yield self.store._user_last_seen_monthly_active(user_id)
+ if timestamp:
+ return
+ # Else if there is no room in the MAU bucket, bail
current_mau = yield self.store.get_monthly_active_count()
if current_mau >= self.hs.config.max_mau_value:
raise AuthError(
403, "MAU Limit Exceeded", errcode=Codes.MAU_LIMIT_EXCEEDED
- )
+ )
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 776ddca638..d3b26a4106 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -208,7 +208,12 @@ class SyncHandler(object):
Returns:
Deferred[SyncResult]
"""
- yield self.auth.check_auth_blocking()
+ # If the user is not part of the mau group, then check that limits have
+ # not been exceeded (if not part of the group by this point, almost certain
+ # auth_blocking will occur)
+ user_id = sync_config.user.to_string()
+ yield self.auth.check_auth_blocking(user_id)
+
res = yield self.response_cache.wrap(
sync_config.request_key,
self._wait_for_sync_for_user,
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 497e4bd933..b95a8743a7 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -13,8 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
+from synapse.api.errors import AuthError, Codes
-from synapse.api.errors import AuthError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig, SyncHandler
from synapse.types import UserID
@@ -31,19 +31,41 @@ class SyncTestCase(tests.unittest.TestCase):
def setUp(self):
self.hs = yield setup_test_homeserver()
self.sync_handler = SyncHandler(self.hs)
+ self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_wait_for_sync_for_user_auth_blocking(self):
- sync_config = SyncConfig(
- user=UserID("@user", "server"),
+
+ user_id1 = "@user1:server"
+ user_id2 = "@user2:server"
+ sync_config = self._generate_sync_config(user_id1)
+
+ self.hs.config.limit_usage_by_mau = True
+ self.hs.config.max_mau_value = 1
+
+ # Check that the happy case does not throw errors
+ yield self.store.upsert_monthly_active_user(user_id1)
+ yield self.sync_handler.wait_for_sync_for_user(sync_config)
+
+ # Test that global lock works
+ self.hs.config.hs_disabled = True
+ with self.assertRaises(AuthError) as e:
+ yield self.sync_handler.wait_for_sync_for_user(sync_config)
+ self.assertEquals(e.exception.errcode, Codes.HS_DISABLED)
+
+ self.hs.config.hs_disabled = False
+
+ sync_config = self._generate_sync_config(user_id2)
+
+ with self.assertRaises(AuthError) as e:
+ yield self.sync_handler.wait_for_sync_for_user(sync_config)
+ self.assertEquals(e.exception.errcode, Codes.MAU_LIMIT_EXCEEDED)
+
+ def _generate_sync_config(self, user_id):
+ return SyncConfig(
+ user=UserID(user_id.split(":")[0][1:], user_id.split(":")[1]),
filter_collection=DEFAULT_FILTER_COLLECTION,
is_guest=False,
request_key="request_key",
device_id="device_id",
)
- # Ensure that an exception is not thrown
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
- self.hs.config.hs_disabled = True
-
- with self.assertRaises(AuthError):
- yield self.sync_handler.wait_for_sync_for_user(sync_config)
|