summary refs log tree commit diff
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2020-05-28 22:38:26 +0100
committerAndrew Morgan <andrew@amorgan.xyz>2020-05-28 22:53:23 +0100
commit82eac22286c4be119d13c46e459d4e3dbcb2f59e (patch)
tree06d28b27a0946792fba52399e277665918339d80
parentRatelimiters are instantiated by the HomeServer class (diff)
downloadsynapse-82eac22286c4be119d13c46e459d4e3dbcb2f59e.tar.xz
Modify servlets to pull Ratelimiters from HomeServer class
-rw-r--r--synapse/config/ratelimiting.py8
-rw-r--r--synapse/handlers/_base.py56
-rw-r--r--synapse/handlers/auth.py10
-rw-r--r--synapse/handlers/message.py1
-rw-r--r--synapse/handlers/register.py2
-rw-r--r--synapse/rest/client/v1/login.py19
-rw-r--r--synapse/rest/client/v2_alpha/register.py2
-rw-r--r--synapse/util/ratelimitutils.py2
8 files changed, 37 insertions, 63 deletions
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 4a3bfc4354..8e42d15fa4 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -14,9 +14,15 @@
 
 from ._base import Config
 
+from typing import Dict
+
 
 class RateLimitConfig(object):
-    def __init__(self, config, defaults={"per_second": 0.17, "burst_count": 3.0}):
+    def __init__(
+        self,
+        config: Dict[str, float],
+        defaults={"per_second": 0.17, "burst_count": 3.0},
+    ):
         self.per_second = config.get("per_second", defaults["per_second"])
         self.burst_count = config.get("burst_count", defaults["burst_count"])
 
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 3b781d9836..206702b6ad 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -19,7 +19,6 @@ from twisted.internet import defer
 
 import synapse.types
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import LimitExceededError
 from synapse.types import UserID
 
 logger = logging.getLogger(__name__)
@@ -44,11 +43,16 @@ class BaseHandler(object):
         self.notifier = hs.get_notifier()
         self.state_handler = hs.get_state_handler()
         self.distributor = hs.get_distributor()
-        self.ratelimiter = hs.get_ratelimiter()
-        self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter()
         self.clock = hs.get_clock()
         self.hs = hs
 
+        self.ratelimiter = None
+        self.request_ratelimiter = hs.get_request_ratelimiter()
+        self._rc_message = self.hs.config.rc_message
+
+        # If special admin redaction ratelimiting is disabled, this will be None
+        self.admin_redaction_ratelimiter = hs.get_admin_redaction_ratelimiter()
+
         self.server_name = hs.hostname
 
         self.event_builder_factory = hs.get_event_builder_factory()
@@ -83,48 +87,30 @@ class BaseHandler(object):
         if requester.app_service and not requester.app_service.is_rate_limited():
             return
 
+        messages_per_second = self._rc_message.per_second
+        burst_count = self._rc_message.burst_count
+
         # Check if there is a per user override in the DB.
         override = yield self.store.get_ratelimit_for_user(user_id)
         if override:
-            # If overriden with a null Hz then ratelimiting has been entirely
+            # If overridden with a null Hz then ratelimiting has been entirely
             # disabled for the user
             if not override.messages_per_second:
                 return
 
             messages_per_second = override.messages_per_second
             burst_count = override.burst_count
+
+        if is_admin_redaction and self.admin_redaction_ratelimiter:
+            # If we have separate config for admin redactions, use a separate
+            # ratelimiter as to not have user_id's clash
+            self.admin_redaction_ratelimiter.ratelimit(user_id, time_now, update)
         else:
-            # We default to different values if this is an admin redaction and
-            # the config is set
-            if is_admin_redaction and self.hs.config.rc_admin_redaction:
-                messages_per_second = self.hs.config.rc_admin_redaction.per_second
-                burst_count = self.hs.config.rc_admin_redaction.burst_count
-            else:
-                messages_per_second = self.hs.config.rc_message.per_second
-                burst_count = self.hs.config.rc_message.burst_count
-
-        if is_admin_redaction and self.hs.config.rc_admin_redaction:
-            # If we have separate config for admin redactions we use a separate
-            # ratelimiter
-            allowed, time_allowed = self.admin_redaction_ratelimiter.can_do_action(
-                user_id,
-                time_now,
-                rate_hz=messages_per_second,
-                burst_count=burst_count,
-                update=update,
-            )
-        else:
-            allowed, time_allowed = self.ratelimiter.can_do_action(
-                user_id,
-                time_now,
-                rate_hz=messages_per_second,
-                burst_count=burst_count,
-                update=update,
-            )
-        if not allowed:
-            raise LimitExceededError(
-                retry_after_ms=int(1000 * (time_allowed - time_now))
-            )
+            # Override rate and burst count per-user
+            self.request_ratelimiter.rate_hz = messages_per_second
+            self.request_ratelimiter.burst_count = burst_count
+
+            self.request_ratelimiter.ratelimit(user_id, time_now, update)
 
     async def maybe_kick_guest_users(self, event, context=None):
         # Technically this function invalidates current_state by changing it.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 75b39e878c..9aab4692f1 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -108,7 +108,11 @@ class AuthHandler(BaseHandler):
 
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
-        self._failed_uia_attempts_ratelimiter = Ratelimiter()
+        # XXX: Should this be hs.get_login_failed_attempts_ratelimiter?
+        self._failed_uia_attempts_ratelimiter = Ratelimiter(
+            rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+            burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+        )
 
         self._clock = self.hs.get_clock()
 
@@ -199,8 +203,6 @@ class AuthHandler(BaseHandler):
         self._failed_uia_attempts_ratelimiter.ratelimit(
             user_id,
             time_now_s=self._clock.time(),
-            rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
-            burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
             update=False,
         )
 
@@ -216,8 +218,6 @@ class AuthHandler(BaseHandler):
             self._failed_uia_attempts_ratelimiter.can_do_action(
                 user_id,
                 time_now_s=self._clock.time(),
-                rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
-                burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
                 update=True,
             )
             raise
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 681f92cafd..649ca1f08a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -362,7 +362,6 @@ class EventCreationHandler(object):
         self.profile_handler = hs.get_profile_handler()
         self.event_builder_factory = hs.get_event_builder_factory()
         self.server_name = hs.hostname
-        self.ratelimiter = hs.get_ratelimiter()
         self.notifier = hs.get_notifier()
         self.config = hs.config
         self.require_membership_for_aliases = hs.config.require_membership_for_aliases
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a6178e74a1..99e2b3fb2c 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -430,8 +430,6 @@ class RegistrationHandler(BaseHandler):
         self.ratelimiter.ratelimit(
             address,
             time_now_s=time_now,
-            rate_hz=self.hs.config.rc_registration.per_second,
-            burst_count=self.hs.config.rc_registration.burst_count,
         )
 
     def register_with_store(
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d89b2e5532..2754a04669 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -89,9 +89,8 @@ class LoginRestServlet(RestServlet):
         self.handlers = hs.get_handlers()
         self._clock = hs.get_clock()
         self._well_known_builder = WellKnownBuilder(hs)
-        self._address_ratelimiter = Ratelimiter()
-        self._account_ratelimiter = Ratelimiter()
-        self._failed_attempts_ratelimiter = Ratelimiter()
+        self._account_ratelimiter = hs.get_login_ratelimiter()
+        self._failed_attempts_ratelimiter = hs.get_login_failed_attempts_ratelimiter()
 
     def on_GET(self, request):
         flows = []
@@ -129,11 +128,9 @@ class LoginRestServlet(RestServlet):
         return 200, {}
 
     async def on_POST(self, request):
-        self._address_ratelimiter.ratelimit(
+        self._account_ratelimiter.ratelimit(
             request.getClientIP(),
             time_now_s=self.hs.clock.time(),
-            rate_hz=self.hs.config.rc_login_address.per_second,
-            burst_count=self.hs.config.rc_login_address.burst_count,
             update=True,
         )
 
@@ -206,8 +203,6 @@ class LoginRestServlet(RestServlet):
             self._failed_attempts_ratelimiter.ratelimit(
                 (medium, address),
                 time_now_s=self._clock.time(),
-                rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
-                burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
                 update=False,
             )
 
@@ -246,8 +241,6 @@ class LoginRestServlet(RestServlet):
                 self._failed_attempts_ratelimiter.can_do_action(
                     (medium, address),
                     time_now_s=self._clock.time(),
-                    rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
-                    burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
                     update=True,
                 )
                 raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@@ -270,8 +263,6 @@ class LoginRestServlet(RestServlet):
         self._failed_attempts_ratelimiter.ratelimit(
             qualified_user_id.lower(),
             time_now_s=self._clock.time(),
-            rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
-            burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
             update=False,
         )
 
@@ -287,8 +278,6 @@ class LoginRestServlet(RestServlet):
             self._failed_attempts_ratelimiter.can_do_action(
                 qualified_user_id.lower(),
                 time_now_s=self._clock.time(),
-                rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
-                burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
                 update=True,
             )
             raise
@@ -326,8 +315,6 @@ class LoginRestServlet(RestServlet):
         self._account_ratelimiter.ratelimit(
             user_id.lower(),
             time_now_s=self._clock.time(),
-            rate_hz=self.hs.config.rc_login_account.per_second,
-            burst_count=self.hs.config.rc_login_account.burst_count,
             update=True,
         )
 
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index addd4cae19..7800604938 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -401,8 +401,6 @@ class RegisterRestServlet(RestServlet):
         allowed, time_allowed = self.ratelimiter.can_do_action(
             client_addr,
             time_now_s=time_now,
-            rate_hz=self.hs.config.rc_registration.per_second,
-            burst_count=self.hs.config.rc_registration.burst_count,
             update=False,
         )
 
diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py
index 5ca4521ce3..e5efdfcd02 100644
--- a/synapse/util/ratelimitutils.py
+++ b/synapse/util/ratelimitutils.py
@@ -43,7 +43,7 @@ class FederationRateLimiter(object):
         self.ratelimiters = collections.defaultdict(new_limiter)
 
     def ratelimit(self, host):
-        """Used to ratelimit an incoming request from given host
+        """Used to ratelimit an incoming request from a given host
 
         Example usage: