summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9711.bugfix1
-rw-r--r--synapse/api/ratelimiting.py100
-rw-r--r--synapse/federation/federation_server.py5
-rw-r--r--synapse/handlers/_base.py14
-rw-r--r--synapse/handlers/auth.py24
-rw-r--r--synapse/handlers/devicemessage.py5
-rw-r--r--synapse/handlers/federation.py2
-rw-r--r--synapse/handlers/identity.py12
-rw-r--r--synapse/handlers/register.py6
-rw-r--r--synapse/handlers/room_member.py23
-rw-r--r--synapse/replication/http/register.py2
-rw-r--r--synapse/rest/client/v1/login.py14
-rw-r--r--synapse/rest/client/v2_alpha/account.py10
-rw-r--r--synapse/rest/client/v2_alpha/register.py8
-rw-r--r--synapse/server.py1
-rw-r--r--tests/api/test_ratelimiting.py168
16 files changed, 241 insertions, 154 deletions
diff --git a/changelog.d/9711.bugfix b/changelog.d/9711.bugfix
new file mode 100644
index 0000000000..4ca3438d46
--- /dev/null
+++ b/changelog.d/9711.bugfix
@@ -0,0 +1 @@
+Fix recently added ratelimits to correctly honour the application service `rate_limited` flag.
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index c3f07bc1a3..2244b8a340 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -17,6 +17,7 @@ from collections import OrderedDict
 from typing import Hashable, Optional, Tuple
 
 from synapse.api.errors import LimitExceededError
+from synapse.storage.databases.main import DataStore
 from synapse.types import Requester
 from synapse.util import Clock
 
@@ -31,10 +32,13 @@ class Ratelimiter:
         burst_count: How many actions that can be performed before being limited.
     """
 
-    def __init__(self, clock: Clock, rate_hz: float, burst_count: int):
+    def __init__(
+        self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int
+    ):
         self.clock = clock
         self.rate_hz = rate_hz
         self.burst_count = burst_count
+        self.store = store
 
         # A ordered dictionary keeping track of actions, when they were last
         # performed and how often. Each entry is a mapping from a key of arbitrary type
@@ -46,45 +50,10 @@ class Ratelimiter:
             OrderedDict()
         )  # type: OrderedDict[Hashable, Tuple[float, int, float]]
 
-    def can_requester_do_action(
-        self,
-        requester: Requester,
-        rate_hz: Optional[float] = None,
-        burst_count: Optional[int] = None,
-        update: bool = True,
-        _time_now_s: Optional[int] = None,
-    ) -> Tuple[bool, float]:
-        """Can the requester perform the action?
-
-        Args:
-            requester: The requester to key off when rate limiting. The user property
-                will be used.
-            rate_hz: The long term number of actions that can be performed in a second.
-                Overrides the value set during instantiation if set.
-            burst_count: How many actions that can be performed before being limited.
-                Overrides the value set during instantiation if set.
-            update: Whether to count this check as performing the action
-            _time_now_s: The current time. Optional, defaults to the current time according
-                to self.clock. Only used by tests.
-
-        Returns:
-            A tuple containing:
-                * A bool indicating if they can perform the action now
-                * The reactor timestamp for when the action can be performed next.
-                  -1 if rate_hz is less than or equal to zero
-        """
-        # Disable rate limiting of users belonging to any AS that is configured
-        # not to be rate limited in its registration file (rate_limited: true|false).
-        if requester.app_service and not requester.app_service.is_rate_limited():
-            return True, -1.0
-
-        return self.can_do_action(
-            requester.user.to_string(), rate_hz, burst_count, update, _time_now_s
-        )
-
-    def can_do_action(
+    async def can_do_action(
         self,
-        key: Hashable,
+        requester: Optional[Requester],
+        key: Optional[Hashable] = None,
         rate_hz: Optional[float] = None,
         burst_count: Optional[int] = None,
         update: bool = True,
@@ -92,9 +61,16 @@ class Ratelimiter:
     ) -> Tuple[bool, float]:
         """Can the entity (e.g. user or IP address) perform the action?
 
+        Checks if the user has ratelimiting disabled in the database by looking
+        for null/zero values in the `ratelimit_override` table. (Non-zero
+        values aren't honoured, as they're specific to the event sending
+        ratelimiter, rather than all ratelimiters)
+
         Args:
-            key: The key we should use when rate limiting. Can be a user ID
-                (when sending events), an IP address, etc.
+            requester: The requester that is doing the action, if any. Used to check
+                if the user has ratelimits disabled in the database.
+            key: An arbitrary key used to classify an action. Defaults to the
+                requester's user ID.
             rate_hz: The long term number of actions that can be performed in a second.
                 Overrides the value set during instantiation if set.
             burst_count: How many actions that can be performed before being limited.
@@ -109,6 +85,30 @@ class Ratelimiter:
                 * The reactor timestamp for when the action can be performed next.
                   -1 if rate_hz is less than or equal to zero
         """
+        if key is None:
+            if not requester:
+                raise ValueError("Must supply at least one of `requester` or `key`")
+
+            key = requester.user.to_string()
+
+        if requester:
+            # Disable rate limiting of users belonging to any AS that is configured
+            # not to be rate limited in its registration file (rate_limited: true|false).
+            if requester.app_service and not requester.app_service.is_rate_limited():
+                return True, -1.0
+
+            # Check if ratelimiting has been disabled for the user.
+            #
+            # Note that we don't use the returned rate/burst count, as the table
+            # is specifically for the event sending ratelimiter. Instead, we
+            # only use it to (somewhat cheekily) infer whether the user should
+            # be subject to any rate limiting or not.
+            override = await self.store.get_ratelimit_for_user(
+                requester.authenticated_entity
+            )
+            if override and not override.messages_per_second:
+                return True, -1.0
+
         # Override default values if set
         time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
         rate_hz = rate_hz if rate_hz is not None else self.rate_hz
@@ -175,9 +175,10 @@ class Ratelimiter:
             else:
                 del self.actions[key]
 
-    def ratelimit(
+    async def ratelimit(
         self,
-        key: Hashable,
+        requester: Optional[Requester],
+        key: Optional[Hashable] = None,
         rate_hz: Optional[float] = None,
         burst_count: Optional[int] = None,
         update: bool = True,
@@ -185,8 +186,16 @@ class Ratelimiter:
     ):
         """Checks if an action can be performed. If not, raises a LimitExceededError
 
+        Checks if the user has ratelimiting disabled in the database by looking
+        for null/zero values in the `ratelimit_override` table. (Non-zero
+        values aren't honoured, as they're specific to the event sending
+        ratelimiter, rather than all ratelimiters)
+
         Args:
-            key: An arbitrary key used to classify an action
+            requester: The requester that is doing the action, if any. Used to check for
+                if the user has ratelimits disabled.
+            key: An arbitrary key used to classify an action. Defaults to the
+                requester's user ID.
             rate_hz: The long term number of actions that can be performed in a second.
                 Overrides the value set during instantiation if set.
             burst_count: How many actions that can be performed before being limited.
@@ -201,7 +210,8 @@ class Ratelimiter:
         """
         time_now_s = _time_now_s if _time_now_s is not None else self.clock.time()
 
-        allowed, time_allowed = self.can_do_action(
+        allowed, time_allowed = await self.can_do_action(
+            requester,
             key,
             rate_hz=rate_hz,
             burst_count=burst_count,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index d84e362070..71cb120ef7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -870,6 +870,7 @@ class FederationHandlerRegistry:
 
         # A rate limiter for incoming room key requests per origin.
         self._room_key_request_rate_limiter = Ratelimiter(
+            store=hs.get_datastore(),
             clock=self.clock,
             rate_hz=self.config.rc_key_requests.per_second,
             burst_count=self.config.rc_key_requests.burst_count,
@@ -930,7 +931,9 @@ class FederationHandlerRegistry:
         # the limit, drop them.
         if (
             edu_type == EduTypes.RoomKeyRequest
-            and not self._room_key_request_rate_limiter.can_do_action(origin)
+            and not await self._room_key_request_rate_limiter.can_do_action(
+                None, origin
+            )
         ):
             return
 
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index aade2c4a3a..fb899aa90d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -49,7 +49,7 @@ class BaseHandler:
 
         # The rate_hz and burst_count are overridden on a per-user basis
         self.request_ratelimiter = Ratelimiter(
-            clock=self.clock, rate_hz=0, burst_count=0
+            store=self.store, clock=self.clock, rate_hz=0, burst_count=0
         )
         self._rc_message = self.hs.config.rc_message
 
@@ -57,6 +57,7 @@ class BaseHandler:
         # by the presence of rate limits in the config
         if self.hs.config.rc_admin_redaction:
             self.admin_redaction_ratelimiter = Ratelimiter(
+                store=self.store,
                 clock=self.clock,
                 rate_hz=self.hs.config.rc_admin_redaction.per_second,
                 burst_count=self.hs.config.rc_admin_redaction.burst_count,
@@ -91,11 +92,6 @@ class BaseHandler:
         if app_service is not None:
             return  # do not ratelimit app service senders
 
-        # Disable rate limiting of users belonging to any AS that is configured
-        # not to be rate limited in its registration file (rate_limited: true|false).
-        if requester.app_service and not requester.app_service.is_rate_limited():
-            return
-
         messages_per_second = self._rc_message.per_second
         burst_count = self._rc_message.burst_count
 
@@ -113,11 +109,11 @@ class BaseHandler:
         if is_admin_redaction and self.admin_redaction_ratelimiter:
             # If we have separate config for admin redactions, use a separate
             # ratelimiter as to not have user_ids clash
-            self.admin_redaction_ratelimiter.ratelimit(user_id, update=update)
+            await self.admin_redaction_ratelimiter.ratelimit(requester, update=update)
         else:
             # Override rate and burst count per-user
-            self.request_ratelimiter.ratelimit(
-                user_id,
+            await self.request_ratelimiter.ratelimit(
+                requester,
                 rate_hz=messages_per_second,
                 burst_count=burst_count,
                 update=update,
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d537ea8137..08e413bc98 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -238,6 +238,7 @@ class AuthHandler(BaseHandler):
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
         self._failed_uia_attempts_ratelimiter = Ratelimiter(
+            store=self.store,
             clock=self.clock,
             rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
             burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
@@ -248,6 +249,7 @@ class AuthHandler(BaseHandler):
 
         # Ratelimitier for failed /login attempts
         self._failed_login_attempts_ratelimiter = Ratelimiter(
+            store=self.store,
             clock=hs.get_clock(),
             rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
             burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
@@ -352,7 +354,7 @@ class AuthHandler(BaseHandler):
         requester_user_id = requester.user.to_string()
 
         # Check if we should be ratelimited due to too many previous failed attempts
-        self._failed_uia_attempts_ratelimiter.ratelimit(requester_user_id, update=False)
+        await self._failed_uia_attempts_ratelimiter.ratelimit(requester, update=False)
 
         # build a list of supported flows
         supported_ui_auth_types = await self._get_available_ui_auth_types(
@@ -373,7 +375,9 @@ class AuthHandler(BaseHandler):
             )
         except LoginError:
             # Update the ratelimiter to say we failed (`can_do_action` doesn't raise).
-            self._failed_uia_attempts_ratelimiter.can_do_action(requester_user_id)
+            await self._failed_uia_attempts_ratelimiter.can_do_action(
+                requester,
+            )
             raise
 
         # find the completed login type
@@ -982,8 +986,8 @@ class AuthHandler(BaseHandler):
             # We also apply account rate limiting using the 3PID as a key, as
             # otherwise using 3PID bypasses the ratelimiting based on user ID.
             if ratelimit:
-                self._failed_login_attempts_ratelimiter.ratelimit(
-                    (medium, address), update=False
+                await self._failed_login_attempts_ratelimiter.ratelimit(
+                    None, (medium, address), update=False
                 )
 
             # Check for login providers that support 3pid login types
@@ -1016,8 +1020,8 @@ class AuthHandler(BaseHandler):
                 # this code path, which is fine as then the per-user ratelimit
                 # will kick in below.
                 if ratelimit:
-                    self._failed_login_attempts_ratelimiter.can_do_action(
-                        (medium, address)
+                    await self._failed_login_attempts_ratelimiter.can_do_action(
+                        None, (medium, address)
                     )
                 raise LoginError(403, "", errcode=Codes.FORBIDDEN)
 
@@ -1039,8 +1043,8 @@ class AuthHandler(BaseHandler):
 
         # Check if we've hit the failed ratelimit (but don't update it)
         if ratelimit:
-            self._failed_login_attempts_ratelimiter.ratelimit(
-                qualified_user_id.lower(), update=False
+            await self._failed_login_attempts_ratelimiter.ratelimit(
+                None, qualified_user_id.lower(), update=False
             )
 
         try:
@@ -1051,8 +1055,8 @@ class AuthHandler(BaseHandler):
             # exception and masking the LoginError. The actual ratelimiting
             # should have happened above.
             if ratelimit:
-                self._failed_login_attempts_ratelimiter.can_do_action(
-                    qualified_user_id.lower()
+                await self._failed_login_attempts_ratelimiter.can_do_action(
+                    None, qualified_user_id.lower()
                 )
             raise
 
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index eb547743be..5ee48be6ff 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -81,6 +81,7 @@ class DeviceMessageHandler:
             )
 
         self._ratelimiter = Ratelimiter(
+            store=self.store,
             clock=hs.get_clock(),
             rate_hz=hs.config.rc_key_requests.per_second,
             burst_count=hs.config.rc_key_requests.burst_count,
@@ -191,8 +192,8 @@ class DeviceMessageHandler:
             if (
                 message_type == EduTypes.RoomKeyRequest
                 and user_id != sender_user_id
-                and self._ratelimiter.can_do_action(
-                    (sender_user_id, requester.device_id)
+                and await self._ratelimiter.can_do_action(
+                    requester, (sender_user_id, requester.device_id)
                 )
             ):
                 continue
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 598a66f74c..3ebee38ebe 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1711,7 +1711,7 @@ class FederationHandler(BaseHandler):
         member_handler = self.hs.get_room_member_handler()
         # We don't rate limit based on room ID, as that should be done by
         # sending server.
-        member_handler.ratelimit_invite(None, event.state_key)
+        await member_handler.ratelimit_invite(None, None, event.state_key)
 
         # keep a record of the room version, if we don't yet know it.
         # (this may get overwritten if we later get a different room version in a
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 5f346f6d6d..d89fa5fb30 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -61,17 +61,19 @@ class IdentityHandler(BaseHandler):
 
         # Ratelimiters for `/requestToken` endpoints.
         self._3pid_validation_ratelimiter_ip = Ratelimiter(
+            store=self.store,
             clock=hs.get_clock(),
             rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
             burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
         )
         self._3pid_validation_ratelimiter_address = Ratelimiter(
+            store=self.store,
             clock=hs.get_clock(),
             rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
             burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
         )
 
-    def ratelimit_request_token_requests(
+    async def ratelimit_request_token_requests(
         self,
         request: SynapseRequest,
         medium: str,
@@ -85,8 +87,12 @@ class IdentityHandler(BaseHandler):
             address: The actual threepid ID, e.g. the phone number or email address
         """
 
-        self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
-        self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+        await self._3pid_validation_ratelimiter_ip.ratelimit(
+            None, (medium, request.getClientIP())
+        )
+        await self._3pid_validation_ratelimiter_address.ratelimit(
+            None, (medium, address)
+        )
 
     async def threepid_from_creds(
         self, id_server: str, creds: Dict[str, str]
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0fc2bf15d5..9701b76d0f 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -204,7 +204,7 @@ class RegistrationHandler(BaseHandler):
         Raises:
             SynapseError if there was a problem registering.
         """
-        self.check_registration_ratelimit(address)
+        await self.check_registration_ratelimit(address)
 
         result = await self.spam_checker.check_registration_for_spam(
             threepid,
@@ -583,7 +583,7 @@ class RegistrationHandler(BaseHandler):
                     errcode=Codes.EXCLUSIVE,
                 )
 
-    def check_registration_ratelimit(self, address: Optional[str]) -> None:
+    async def check_registration_ratelimit(self, address: Optional[str]) -> None:
         """A simple helper method to check whether the registration rate limit has been hit
         for a given IP address
 
@@ -597,7 +597,7 @@ class RegistrationHandler(BaseHandler):
         if not address:
             return
 
-        self.ratelimiter.ratelimit(address)
+        await self.ratelimiter.ratelimit(None, address)
 
     async def register_with_store(
         self,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4d20ed8357..1cf12f3255 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -75,22 +75,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.allow_per_room_profiles = self.config.allow_per_room_profiles
 
         self._join_rate_limiter_local = Ratelimiter(
+            store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
         )
         self._join_rate_limiter_remote = Ratelimiter(
+            store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
         )
 
         self._invites_per_room_limiter = Ratelimiter(
+            store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second,
             burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
         )
         self._invites_per_user_limiter = Ratelimiter(
+            store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second,
             burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
@@ -159,15 +163,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
     async def forget(self, user: UserID, room_id: str) -> None:
         raise NotImplementedError()
 
-    def ratelimit_invite(self, room_id: Optional[str], invitee_user_id: str):
+    async def ratelimit_invite(
+        self,
+        requester: Optional[Requester],
+        room_id: Optional[str],
+        invitee_user_id: str,
+    ):
         """Ratelimit invites by room and by target user.
 
         If room ID is missing then we just rate limit by target user.
         """
         if room_id:
-            self._invites_per_room_limiter.ratelimit(room_id)
+            await self._invites_per_room_limiter.ratelimit(requester, room_id)
 
-        self._invites_per_user_limiter.ratelimit(invitee_user_id)
+        await self._invites_per_user_limiter.ratelimit(requester, invitee_user_id)
 
     async def _local_membership_update(
         self,
@@ -237,7 +246,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 (
                     allowed,
                     time_allowed,
-                ) = self._join_rate_limiter_local.can_requester_do_action(requester)
+                ) = await self._join_rate_limiter_local.can_do_action(requester)
 
                 if not allowed:
                     raise LimitExceededError(
@@ -421,9 +430,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         if effective_membership_state == Membership.INVITE:
             target_id = target.to_string()
             if ratelimit:
-                # Don't ratelimit application services.
-                if not requester.app_service or requester.app_service.is_rate_limited():
-                    self.ratelimit_invite(room_id, target_id)
+                await self.ratelimit_invite(requester, room_id, target_id)
 
             # block any attempts to invite the server notices mxid
             if target_id == self._server_notices_mxid:
@@ -534,7 +541,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     (
                         allowed,
                         time_allowed,
-                    ) = self._join_rate_limiter_remote.can_requester_do_action(
+                    ) = await self._join_rate_limiter_remote.can_do_action(
                         requester,
                     )
 
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index d005f38767..73d7477854 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -77,7 +77,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
     async def _handle_request(self, request, user_id):
         content = parse_json_object_from_request(request)
 
-        self.registration_handler.check_registration_ratelimit(content["address"])
+        await self.registration_handler.check_registration_ratelimit(content["address"])
 
         await self.registration_handler.register_with_store(
             user_id=user_id,
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index e4c352f572..3151e72d4f 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -74,11 +74,13 @@ class LoginRestServlet(RestServlet):
 
         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
+            store=hs.get_datastore(),
             clock=hs.get_clock(),
             rate_hz=self.hs.config.rc_login_address.per_second,
             burst_count=self.hs.config.rc_login_address.burst_count,
         )
         self._account_ratelimiter = Ratelimiter(
+            store=hs.get_datastore(),
             clock=hs.get_clock(),
             rate_hz=self.hs.config.rc_login_account.per_second,
             burst_count=self.hs.config.rc_login_account.burst_count,
@@ -141,20 +143,22 @@ class LoginRestServlet(RestServlet):
                 appservice = self.auth.get_appservice_by_req(request)
 
                 if appservice.is_rate_limited():
-                    self._address_ratelimiter.ratelimit(request.getClientIP())
+                    await self._address_ratelimiter.ratelimit(
+                        None, request.getClientIP()
+                    )
 
                 result = await self._do_appservice_login(login_submission, appservice)
             elif self.jwt_enabled and (
                 login_submission["type"] == LoginRestServlet.JWT_TYPE
                 or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
             ):
-                self._address_ratelimiter.ratelimit(request.getClientIP())
+                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
                 result = await self._do_jwt_login(login_submission)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
-                self._address_ratelimiter.ratelimit(request.getClientIP())
+                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
                 result = await self._do_token_login(login_submission)
             else:
-                self._address_ratelimiter.ratelimit(request.getClientIP())
+                await self._address_ratelimiter.ratelimit(None, request.getClientIP())
                 result = await self._do_other_login(login_submission)
         except KeyError:
             raise SynapseError(400, "Missing JSON keys.")
@@ -258,7 +262,7 @@ class LoginRestServlet(RestServlet):
         # too often. This happens here rather than before as we don't
         # necessarily know the user before now.
         if ratelimit:
-            self._account_ratelimiter.ratelimit(user_id.lower())
+            await self._account_ratelimiter.ratelimit(None, user_id.lower())
 
         if create_non_existent_users:
             canonical_uid = await self.auth_handler.check_user_exists(user_id)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index c2ba790bab..411fb57c47 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -103,7 +103,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
 
-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+        await self.identity_handler.ratelimit_request_token_requests(
+            request, "email", email
+        )
 
         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to
@@ -387,7 +389,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+        await self.identity_handler.ratelimit_request_token_requests(
+            request, "email", email
+        )
 
         if next_link:
             # Raise if the provided next_link value isn't valid
@@ -468,7 +472,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
-        self.identity_handler.ratelimit_request_token_requests(
+        await self.identity_handler.ratelimit_request_token_requests(
             request, "msisdn", msisdn
         )
 
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 8f68d8dfc8..c212da0cb2 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -126,7 +126,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
-        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+        await self.identity_handler.ratelimit_request_token_requests(
+            request, "email", email
+        )
 
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "email", email
@@ -208,7 +210,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
-        self.identity_handler.ratelimit_request_token_requests(
+        await self.identity_handler.ratelimit_request_token_requests(
             request, "msisdn", msisdn
         )
 
@@ -406,7 +408,7 @@ class RegisterRestServlet(RestServlet):
 
         client_addr = request.getClientIP()
 
-        self.ratelimiter.ratelimit(client_addr, update=False)
+        await self.ratelimiter.ratelimit(None, client_addr, update=False)
 
         kind = b"user"
         if b"kind" in request.args:
diff --git a/synapse/server.py b/synapse/server.py
index e85b9391fa..e42f7b1a18 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -329,6 +329,7 @@ class HomeServer(metaclass=abc.ABCMeta):
     @cache_in_self
     def get_registration_ratelimiter(self) -> Ratelimiter:
         return Ratelimiter(
+            store=self.get_datastore(),
             clock=self.get_clock(),
             rate_hz=self.config.rc_registration.per_second,
             burst_count=self.config.rc_registration.burst_count,
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 483418192c..fa96ba07a5 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -5,38 +5,25 @@ from synapse.types import create_requester
 from tests import unittest
 
 
-class TestRatelimiter(unittest.TestCase):
+class TestRatelimiter(unittest.HomeserverTestCase):
     def test_allowed_via_can_do_action(self):
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=0)
-        self.assertTrue(allowed)
-        self.assertEquals(10.0, time_allowed)
-
-        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=5)
-        self.assertFalse(allowed)
-        self.assertEquals(10.0, time_allowed)
-
-        allowed, time_allowed = limiter.can_do_action(key="test_id", _time_now_s=10)
-        self.assertTrue(allowed)
-        self.assertEquals(20.0, time_allowed)
-
-    def test_allowed_user_via_can_requester_do_action(self):
-        user_requester = create_requester("@user:example.com")
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_requester_do_action(
-            user_requester, _time_now_s=0
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id", _time_now_s=0)
         )
         self.assertTrue(allowed)
         self.assertEquals(10.0, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            user_requester, _time_now_s=5
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id", _time_now_s=5)
         )
         self.assertFalse(allowed)
         self.assertEquals(10.0, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            user_requester, _time_now_s=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id", _time_now_s=10)
         )
         self.assertTrue(allowed)
         self.assertEquals(20.0, time_allowed)
@@ -51,21 +38,23 @@ class TestRatelimiter(unittest.TestCase):
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)
 
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=0
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=0)
         )
         self.assertTrue(allowed)
         self.assertEquals(10.0, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=5
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=5)
         )
         self.assertFalse(allowed)
         self.assertEquals(10.0, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=10)
         )
         self.assertTrue(allowed)
         self.assertEquals(20.0, time_allowed)
@@ -80,73 +69,89 @@ class TestRatelimiter(unittest.TestCase):
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)
 
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=0
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=0)
         )
         self.assertTrue(allowed)
         self.assertEquals(-1, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=5
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=5)
         )
         self.assertTrue(allowed)
         self.assertEquals(-1, time_allowed)
 
-        allowed, time_allowed = limiter.can_requester_do_action(
-            as_requester, _time_now_s=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(as_requester, _time_now_s=10)
         )
         self.assertTrue(allowed)
         self.assertEquals(-1, time_allowed)
 
     def test_allowed_via_ratelimit(self):
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
 
         # Shouldn't raise
-        limiter.ratelimit(key="test_id", _time_now_s=0)
+        self.get_success_or_raise(limiter.ratelimit(None, key="test_id", _time_now_s=0))
 
         # Should raise
         with self.assertRaises(LimitExceededError) as context:
-            limiter.ratelimit(key="test_id", _time_now_s=5)
+            self.get_success_or_raise(
+                limiter.ratelimit(None, key="test_id", _time_now_s=5)
+            )
         self.assertEqual(context.exception.retry_after_ms, 5000)
 
         # Shouldn't raise
-        limiter.ratelimit(key="test_id", _time_now_s=10)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key="test_id", _time_now_s=10)
+        )
 
     def test_allowed_via_can_do_action_and_overriding_parameters(self):
         """Test that we can override options of can_do_action that would otherwise fail
         an action
         """
         # Create a Ratelimiter with a very low allowed rate_hz and burst_count
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
 
         # First attempt should be allowed
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",),
-            _time_now_s=0,
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(
+                None,
+                ("test_id",),
+                _time_now_s=0,
+            )
         )
         self.assertTrue(allowed)
         self.assertEqual(10.0, time_allowed)
 
         # Second attempt, 1s later, will fail
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",),
-            _time_now_s=1,
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(
+                None,
+                ("test_id",),
+                _time_now_s=1,
+            )
         )
         self.assertFalse(allowed)
         self.assertEqual(10.0, time_allowed)
 
         # But, if we allow 10 actions/sec for this request, we should be allowed
         # to continue.
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",), _time_now_s=1, rate_hz=10.0
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, ("test_id",), _time_now_s=1, rate_hz=10.0)
         )
         self.assertTrue(allowed)
         self.assertEqual(1.1, time_allowed)
 
         # Similarly if we allow a burst of 10 actions
-        allowed, time_allowed = limiter.can_do_action(
-            ("test_id",), _time_now_s=1, burst_count=10
+        allowed, time_allowed = self.get_success_or_raise(
+            limiter.can_do_action(None, ("test_id",), _time_now_s=1, burst_count=10)
         )
         self.assertTrue(allowed)
         self.assertEqual(1.0, time_allowed)
@@ -156,29 +161,72 @@ class TestRatelimiter(unittest.TestCase):
         fail an action
         """
         # Create a Ratelimiter with a very low allowed rate_hz and burst_count
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
 
         # First attempt should be allowed
-        limiter.ratelimit(key=("test_id",), _time_now_s=0)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key=("test_id",), _time_now_s=0)
+        )
 
         # Second attempt, 1s later, will fail
         with self.assertRaises(LimitExceededError) as context:
-            limiter.ratelimit(key=("test_id",), _time_now_s=1)
+            self.get_success_or_raise(
+                limiter.ratelimit(None, key=("test_id",), _time_now_s=1)
+            )
         self.assertEqual(context.exception.retry_after_ms, 9000)
 
         # But, if we allow 10 actions/sec for this request, we should be allowed
         # to continue.
-        limiter.ratelimit(key=("test_id",), _time_now_s=1, rate_hz=10.0)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key=("test_id",), _time_now_s=1, rate_hz=10.0)
+        )
 
         # Similarly if we allow a burst of 10 actions
-        limiter.ratelimit(key=("test_id",), _time_now_s=1, burst_count=10)
+        self.get_success_or_raise(
+            limiter.ratelimit(None, key=("test_id",), _time_now_s=1, burst_count=10)
+        )
 
     def test_pruning(self):
-        limiter = Ratelimiter(clock=None, rate_hz=0.1, burst_count=1)
-        limiter.can_do_action(key="test_id_1", _time_now_s=0)
+        limiter = Ratelimiter(
+            store=self.hs.get_datastore(), clock=None, rate_hz=0.1, burst_count=1
+        )
+        self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
+        )
 
         self.assertIn("test_id_1", limiter.actions)
 
-        limiter.can_do_action(key="test_id_2", _time_now_s=10)
+        self.get_success_or_raise(
+            limiter.can_do_action(None, key="test_id_2", _time_now_s=10)
+        )
 
         self.assertNotIn("test_id_1", limiter.actions)
+
+    def test_db_user_override(self):
+        """Test that users that have ratelimiting disabled in the DB aren't
+        ratelimited.
+        """
+        store = self.hs.get_datastore()
+
+        user_id = "@user:test"
+        requester = create_requester(user_id)
+
+        self.get_success(
+            store.db_pool.simple_insert(
+                table="ratelimit_override",
+                values={
+                    "user_id": user_id,
+                    "messages_per_second": None,
+                    "burst_count": None,
+                },
+                desc="test_db_user_override",
+            )
+        )
+
+        limiter = Ratelimiter(store=store, clock=None, rate_hz=0.1, burst_count=1)
+
+        # Shouldn't raise
+        for _ in range(20):
+            self.get_success_or_raise(limiter.ratelimit(requester, _time_now_s=0))