summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/4735.feature1
-rw-r--r--docs/sample_config.yaml11
-rw-r--r--synapse/api/ratelimiting.py31
-rw-r--r--synapse/config/registration.py18
-rw-r--r--synapse/handlers/_base.py4
-rw-r--r--synapse/handlers/register.py39
-rw-r--r--synapse/replication/http/register.py8
-rw-r--r--synapse/rest/client/v2_alpha/register.py33
-rw-r--r--tests/api/test_ratelimiting.py20
-rw-r--r--tests/handlers/test_profile.py4
-rw-r--r--tests/replication/slave/storage/_base.py4
-rw-r--r--tests/rest/client/v1/test_events.py4
-rw-r--r--tests/rest/client/v1/test_rooms.py6
-rw-r--r--tests/rest/client/v1/test_typing.py4
-rw-r--r--tests/rest/client/v2_alpha/test_register.py48
-rw-r--r--tests/test_mau.py3
-rw-r--r--tests/utils.py2
17 files changed, 186 insertions, 54 deletions
diff --git a/changelog.d/4735.feature b/changelog.d/4735.feature
new file mode 100644
index 0000000000..a4c0b196f6
--- /dev/null
+++ b/changelog.d/4735.feature
@@ -0,0 +1 @@
+Add configurable rate limiting to the /register endpoint.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index 7cf58d2182..e0140003fd 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -657,6 +657,17 @@ trusted_third_party_id_servers:
 #
 autocreate_auto_join_rooms: true
 
+# Number of registration requests a client can send per second.
+# Defaults to 1/minute (0.17).
+#
+#rc_registration_requests_per_second: 0.17
+
+# Number of registration requests a client can send before being
+# throttled.
+# Defaults to 3.
+#
+#rc_registration_request_burst_count: 3.0
+
 
 ## Metrics ###
 
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index 3bb5b3da37..ad68079eeb 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -23,12 +23,13 @@ class Ratelimiter(object):
     def __init__(self):
         self.message_counts = collections.OrderedDict()
 
-    def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count, update=True):
-        """Can the user send a message?
+    def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True):
+        """Can the entity (e.g. user or IP address) perform the action?
         Args:
-            user_id: The user sending a message.
+            key: The key we should use when rate limiting. Can be a user ID
+                (when sending events), an IP address, etc.
             time_now_s: The time now.
-            msg_rate_hz: The long term number of messages a user can send in a
+            rate_hz: The long term number of messages a user can send in a
                 second.
             burst_count: How many messages the user can send before being
                 limited.
@@ -41,10 +42,10 @@ class Ratelimiter(object):
         """
         self.prune_message_counts(time_now_s)
         message_count, time_start, _ignored = self.message_counts.get(
-            user_id, (0., time_now_s, None),
+            key, (0., time_now_s, None),
         )
         time_delta = time_now_s - time_start
-        sent_count = message_count - time_delta * msg_rate_hz
+        sent_count = message_count - time_delta * rate_hz
         if sent_count < 0:
             allowed = True
             time_start = time_now_s
@@ -56,13 +57,13 @@ class Ratelimiter(object):
             message_count += 1
 
         if update:
-            self.message_counts[user_id] = (
-                message_count, time_start, msg_rate_hz
+            self.message_counts[key] = (
+                message_count, time_start, rate_hz
             )
 
-        if msg_rate_hz > 0:
+        if rate_hz > 0:
             time_allowed = (
-                time_start + (message_count - burst_count + 1) / msg_rate_hz
+                time_start + (message_count - burst_count + 1) / rate_hz
             )
             if time_allowed < time_now_s:
                 time_allowed = time_now_s
@@ -72,12 +73,12 @@ class Ratelimiter(object):
         return allowed, time_allowed
 
     def prune_message_counts(self, time_now_s):
-        for user_id in list(self.message_counts.keys()):
-            message_count, time_start, msg_rate_hz = (
-                self.message_counts[user_id]
+        for key in list(self.message_counts.keys()):
+            message_count, time_start, rate_hz = (
+                self.message_counts[key]
             )
             time_delta = time_now_s - time_start
-            if message_count - time_delta * msg_rate_hz > 0:
+            if message_count - time_delta * rate_hz > 0:
                 break
             else:
-                del self.message_counts[user_id]
+                del self.message_counts[key]
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index 2881482f96..d32f6fff73 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -54,6 +54,13 @@ class RegistrationConfig(Config):
             config.get("disable_msisdn_registration", False)
         )
 
+        self.rc_registration_requests_per_second = config.get(
+            "rc_registration_requests_per_second", 0.17,
+        )
+        self.rc_registration_request_burst_count = config.get(
+            "rc_registration_request_burst_count", 3,
+        )
+
     def default_config(self, generate_secrets=False, **kwargs):
         if generate_secrets:
             registration_shared_secret = 'registration_shared_secret: "%s"' % (
@@ -140,6 +147,17 @@ class RegistrationConfig(Config):
         # users cannot be auto-joined since they do not exist.
         #
         autocreate_auto_join_rooms: true
+
+        # Number of registration requests a client can send per second.
+        # Defaults to 1/minute (0.17).
+        #
+        #rc_registration_requests_per_second: 0.17
+
+        # Number of registration requests a client can send before being
+        # throttled.
+        # Defaults to 3.
+        #
+        #rc_registration_request_burst_count: 3.0
         """ % locals()
 
     def add_arguments(self, parser):
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 594754cfd8..d8d86d6ff3 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -93,9 +93,9 @@ class BaseHandler(object):
             messages_per_second = self.hs.config.rc_messages_per_second
             burst_count = self.hs.config.rc_message_burst_count
 
-        allowed, time_allowed = self.ratelimiter.send_message(
+        allowed, time_allowed = self.ratelimiter.can_do_action(
             user_id, time_now,
-            msg_rate_hz=messages_per_second,
+            rate_hz=messages_per_second,
             burst_count=burst_count,
             update=update,
         )
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c0e06929bd..47d5e276f8 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -24,6 +24,7 @@ from synapse.api.errors import (
     AuthError,
     Codes,
     InvalidCaptchaError,
+    LimitExceededError,
     RegistrationError,
     SynapseError,
 )
@@ -60,6 +61,7 @@ class RegistrationHandler(BaseHandler):
         self.user_directory_handler = hs.get_user_directory_handler()
         self.captcha_client = CaptchaServerHttpClient(hs)
         self.identity_handler = self.hs.get_handlers().identity_handler
+        self.ratelimiter = hs.get_ratelimiter()
 
         self._next_generated_user_id = None
 
@@ -149,6 +151,7 @@ class RegistrationHandler(BaseHandler):
         threepid=None,
         user_type=None,
         default_display_name=None,
+        address=None,
     ):
         """Registers a new client on the server.
 
@@ -167,6 +170,7 @@ class RegistrationHandler(BaseHandler):
               api.constants.UserTypes, or None for a normal user.
             default_display_name (unicode|None): if set, the new user's displayname
               will be set to this. Defaults to 'localpart'.
+            address (str|None): the IP address used to perform the regitration.
         Returns:
             A tuple of (user_id, access_token).
         Raises:
@@ -206,7 +210,7 @@ class RegistrationHandler(BaseHandler):
             token = None
             if generate_token:
                 token = self.macaroon_gen.generate_access_token(user_id)
-            yield self._register_with_store(
+            yield self.register_with_store(
                 user_id=user_id,
                 token=token,
                 password_hash=password_hash,
@@ -215,6 +219,7 @@ class RegistrationHandler(BaseHandler):
                 create_profile_with_displayname=default_display_name,
                 admin=admin,
                 user_type=user_type,
+                address=address,
             )
 
             if self.hs.config.user_directory_search_all_users:
@@ -238,12 +243,13 @@ class RegistrationHandler(BaseHandler):
                 if default_display_name is None:
                     default_display_name = localpart
                 try:
-                    yield self._register_with_store(
+                    yield self.register_with_store(
                         user_id=user_id,
                         token=token,
                         password_hash=password_hash,
                         make_guest=make_guest,
                         create_profile_with_displayname=default_display_name,
+                        address=address,
                     )
                 except SynapseError:
                     # if user id is taken, just generate another
@@ -337,7 +343,7 @@ class RegistrationHandler(BaseHandler):
             user_id, allowed_appservice=service
         )
 
-        yield self._register_with_store(
+        yield self.register_with_store(
             user_id=user_id,
             password_hash="",
             appservice_id=service_id,
@@ -513,7 +519,7 @@ class RegistrationHandler(BaseHandler):
         token = self.macaroon_gen.generate_access_token(user_id)
 
         if need_register:
-            yield self._register_with_store(
+            yield self.register_with_store(
                 user_id=user_id,
                 token=token,
                 password_hash=password_hash,
@@ -590,10 +596,10 @@ class RegistrationHandler(BaseHandler):
             ratelimit=False,
         )
 
-    def _register_with_store(self, user_id, token=None, password_hash=None,
-                             was_guest=False, make_guest=False, appservice_id=None,
-                             create_profile_with_displayname=None, admin=False,
-                             user_type=None):
+    def register_with_store(self, user_id, token=None, password_hash=None,
+                            was_guest=False, make_guest=False, appservice_id=None,
+                            create_profile_with_displayname=None, admin=False,
+                            user_type=None, address=None):
         """Register user in the datastore.
 
         Args:
@@ -612,10 +618,26 @@ class RegistrationHandler(BaseHandler):
             admin (boolean): is an admin user?
             user_type (str|None): type of user. One of the values from
                 api.constants.UserTypes, or None for a normal user.
+            address (str|None): the IP address used to perform the regitration.
 
         Returns:
             Deferred
         """
+        # Don't rate limit for app services
+        if appservice_id is None and address is not None:
+            time_now = self.clock.time()
+
+            allowed, time_allowed = self.ratelimiter.can_do_action(
+                address, time_now_s=time_now,
+                rate_hz=self.hs.config.rc_registration_requests_per_second,
+                burst_count=self.hs.config.rc_registration_request_burst_count,
+            )
+
+            if not allowed:
+                raise LimitExceededError(
+                    retry_after_ms=int(1000 * (time_allowed - time_now)),
+                )
+
         if self.hs.config.worker_app:
             return self._register_client(
                 user_id=user_id,
@@ -627,6 +649,7 @@ class RegistrationHandler(BaseHandler):
                 create_profile_with_displayname=create_profile_with_displayname,
                 admin=admin,
                 user_type=user_type,
+                address=address,
             )
         else:
             return self.store.register(
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 1d27c9221f..912a5ac341 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -33,11 +33,12 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
     def __init__(self, hs):
         super(ReplicationRegisterServlet, self).__init__(hs)
         self.store = hs.get_datastore()
+        self.registration_handler = hs.get_registration_handler()
 
     @staticmethod
     def _serialize_payload(
         user_id, token, password_hash, was_guest, make_guest, appservice_id,
-        create_profile_with_displayname, admin, user_type,
+        create_profile_with_displayname, admin, user_type, address,
     ):
         """
         Args:
@@ -56,6 +57,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             admin (boolean): is an admin user?
             user_type (str|None): type of user. One of the values from
                 api.constants.UserTypes, or None for a normal user.
+            address (str|None): the IP address used to perform the regitration.
         """
         return {
             "token": token,
@@ -66,13 +68,14 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             "create_profile_with_displayname": create_profile_with_displayname,
             "admin": admin,
             "user_type": user_type,
+            "address": address,
         }
 
     @defer.inlineCallbacks
     def _handle_request(self, request, user_id):
         content = parse_json_object_from_request(request)
 
-        yield self.store.register(
+        yield self.registration_handler.register_with_store(
             user_id=user_id,
             token=content["token"],
             password_hash=content["password_hash"],
@@ -82,6 +85,7 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             create_profile_with_displayname=content["create_profile_with_displayname"],
             admin=content["admin"],
             user_type=content["user_type"],
+            address=content["address"]
         )
 
         defer.returnValue((200, {}))
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 94cbba4303..b7f354570c 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -25,7 +25,12 @@ from twisted.internet import defer
 import synapse
 import synapse.types
 from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
+from synapse.api.errors import (
+    Codes,
+    LimitExceededError,
+    SynapseError,
+    UnrecognizedRequestError,
+)
 from synapse.config.server import is_threepid_reserved
 from synapse.http.servlet import (
     RestServlet,
@@ -191,18 +196,36 @@ class RegisterRestServlet(RestServlet):
         self.identity_handler = hs.get_handlers().identity_handler
         self.room_member_handler = hs.get_room_member_handler()
         self.macaroon_gen = hs.get_macaroon_generator()
+        self.ratelimiter = hs.get_ratelimiter()
+        self.clock = hs.get_clock()
 
     @interactive_auth_handler
     @defer.inlineCallbacks
     def on_POST(self, request):
         body = parse_json_object_from_request(request)
 
+        client_addr = request.getClientIP()
+
+        time_now = self.clock.time()
+
+        allowed, time_allowed = self.ratelimiter.can_do_action(
+            client_addr, time_now_s=time_now,
+            rate_hz=self.hs.config.rc_registration_requests_per_second,
+            burst_count=self.hs.config.rc_registration_request_burst_count,
+            update=False,
+        )
+
+        if not allowed:
+            raise LimitExceededError(
+                retry_after_ms=int(1000 * (time_allowed - time_now)),
+            )
+
         kind = b"user"
         if b"kind" in request.args:
             kind = request.args[b"kind"][0]
 
         if kind == b"guest":
-            ret = yield self._do_guest_registration(body)
+            ret = yield self._do_guest_registration(body, address=client_addr)
             defer.returnValue(ret)
             return
         elif kind != b"user":
@@ -411,6 +434,7 @@ class RegisterRestServlet(RestServlet):
                 guest_access_token=guest_access_token,
                 generate_token=False,
                 threepid=threepid,
+                address=client_addr,
             )
             # Necessary due to auth checks prior to the threepid being
             # written to the db
@@ -522,12 +546,13 @@ class RegisterRestServlet(RestServlet):
         defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def _do_guest_registration(self, params):
+    def _do_guest_registration(self, params, address=None):
         if not self.hs.config.allow_guest_access:
             raise SynapseError(403, "Guest access is disabled")
         user_id, _ = yield self.registration_handler.register(
             generate_token=False,
-            make_guest=True
+            make_guest=True,
+            address=address,
         )
 
         # we don't allow guests to specify their own device_id, because
diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py
index 8933fe3b72..30a255d441 100644
--- a/tests/api/test_ratelimiting.py
+++ b/tests/api/test_ratelimiting.py
@@ -6,34 +6,34 @@ from tests import unittest
 class TestRatelimiter(unittest.TestCase):
     def test_allowed(self):
         limiter = Ratelimiter()
-        allowed, time_allowed = limiter.send_message(
-            user_id="test_id", time_now_s=0, msg_rate_hz=0.1, burst_count=1
+        allowed, time_allowed = limiter.can_do_action(
+            key="test_id", time_now_s=0, rate_hz=0.1, burst_count=1
         )
         self.assertTrue(allowed)
         self.assertEquals(10., time_allowed)
 
-        allowed, time_allowed = limiter.send_message(
-            user_id="test_id", time_now_s=5, msg_rate_hz=0.1, burst_count=1
+        allowed, time_allowed = limiter.can_do_action(
+            key="test_id", time_now_s=5, rate_hz=0.1, burst_count=1
         )
         self.assertFalse(allowed)
         self.assertEquals(10., time_allowed)
 
-        allowed, time_allowed = limiter.send_message(
-            user_id="test_id", time_now_s=10, msg_rate_hz=0.1, burst_count=1
+        allowed, time_allowed = limiter.can_do_action(
+            key="test_id", time_now_s=10, rate_hz=0.1, burst_count=1
         )
         self.assertTrue(allowed)
         self.assertEquals(20., time_allowed)
 
     def test_pruning(self):
         limiter = Ratelimiter()
-        allowed, time_allowed = limiter.send_message(
-            user_id="test_id_1", time_now_s=0, msg_rate_hz=0.1, burst_count=1
+        allowed, time_allowed = limiter.can_do_action(
+            key="test_id_1", time_now_s=0, rate_hz=0.1, burst_count=1
         )
 
         self.assertIn("test_id_1", limiter.message_counts)
 
-        allowed, time_allowed = limiter.send_message(
-            user_id="test_id_2", time_now_s=10, msg_rate_hz=0.1, burst_count=1
+        allowed, time_allowed = limiter.can_do_action(
+            key="test_id_2", time_now_s=10, rate_hz=0.1, burst_count=1
         )
 
         self.assertNotIn("test_id_1", limiter.message_counts)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 80da1c8954..d60c124eec 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -55,11 +55,11 @@ class ProfileTestCase(unittest.TestCase):
             federation_client=self.mock_federation,
             federation_server=Mock(),
             federation_registry=self.mock_registry,
-            ratelimiter=NonCallableMock(spec_set=["send_message"]),
+            ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
         )
 
         self.ratelimiter = hs.get_ratelimiter()
-        self.ratelimiter.send_message.return_value = (True, 0)
+        self.ratelimiter.can_do_action.return_value = (True, 0)
 
         self.store = hs.get_datastore()
 
diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 9e9fbbfe93..524af4f8d1 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -31,10 +31,10 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
         hs = self.setup_test_homeserver(
             "blue",
             federation_client=Mock(),
-            ratelimiter=NonCallableMock(spec_set=["send_message"]),
+            ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
         )
 
-        hs.get_ratelimiter().send_message.return_value = (True, 0)
+        hs.get_ratelimiter().can_do_action.return_value = (True, 0)
 
         return hs
 
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 483bebc832..36d8547275 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -40,10 +40,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         config.auto_join_rooms = []
 
         hs = self.setup_test_homeserver(
-            config=config, ratelimiter=NonCallableMock(spec_set=["send_message"])
+            config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])
         )
         self.ratelimiter = hs.get_ratelimiter()
-        self.ratelimiter.send_message.return_value = (True, 0)
+        self.ratelimiter.can_do_action.return_value = (True, 0)
 
         hs.get_handlers().federation_handler = Mock()
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index a824be9a62..015c144248 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -41,10 +41,10 @@ class RoomBase(unittest.HomeserverTestCase):
             "red",
             http_client=None,
             federation_client=Mock(),
-            ratelimiter=NonCallableMock(spec_set=["send_message"]),
+            ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
         )
         self.ratelimiter = self.hs.get_ratelimiter()
-        self.ratelimiter.send_message.return_value = (True, 0)
+        self.ratelimiter.can_do_action.return_value = (True, 0)
 
         self.hs.get_federation_handler = Mock(return_value=Mock())
 
@@ -96,7 +96,7 @@ class RoomPermissionsTestCase(RoomBase):
         # auth as user_id now
         self.helper.auth_user_id = self.user_id
 
-    def test_send_message(self):
+    def test_can_do_action(self):
         msg_content = b'{"msgtype":"m.text","body":"hello"}'
 
         seq = iter(range(100))
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 0ad814c5e5..30fb77bac8 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -42,13 +42,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
             "red",
             http_client=None,
             federation_client=Mock(),
-            ratelimiter=NonCallableMock(spec_set=["send_message"]),
+            ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
         )
 
         self.event_source = hs.get_event_sources().sources["typing"]
 
         self.ratelimiter = hs.get_ratelimiter()
-        self.ratelimiter.send_message.return_value = (True, 0)
+        self.ratelimiter.can_do_action.return_value = (True, 0)
 
         hs.get_handlers().federation_handler = Mock()
 
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 906b348d3e..3600434858 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -130,3 +130,51 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Guest access is disabled")
+
+    def test_POST_ratelimiting_guest(self):
+        self.hs.config.rc_registration_request_burst_count = 5
+
+        for i in range(0, 6):
+            url = self.url + b"?kind=guest"
+            request, channel = self.make_request(b"POST", url, b"{}")
+            self.render(request)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(retry_after_ms / 1000.)
+
+        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        self.render(request)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    def test_POST_ratelimiting(self):
+        self.hs.config.rc_registration_request_burst_count = 5
+
+        for i in range(0, 6):
+            params = {
+                "username": "kermit" + str(i),
+                "password": "monkey",
+                "device_id": "frogfone",
+                "auth": {"type": LoginType.DUMMY},
+            }
+            request_data = json.dumps(params)
+            request, channel = self.make_request(b"POST", self.url, request_data)
+            self.render(request)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(retry_after_ms / 1000.)
+
+        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        self.render(request)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
diff --git a/tests/test_mau.py b/tests/test_mau.py
index 04f95c942f..00be1a8c21 100644
--- a/tests/test_mau.py
+++ b/tests/test_mau.py
@@ -17,7 +17,7 @@
 
 import json
 
-from mock import Mock, NonCallableMock
+from mock import Mock
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
@@ -36,7 +36,6 @@ class TestMauLimit(unittest.HomeserverTestCase):
             "red",
             http_client=None,
             federation_client=Mock(),
-            ratelimiter=NonCallableMock(spec_set=["send_message"]),
         )
 
         self.store = self.hs.get_datastore()
diff --git a/tests/utils.py b/tests/utils.py
index ee272157aa..e4c42f9fa8 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -150,6 +150,8 @@ def default_config(name):
     config.admin_contact = None
     config.rc_messages_per_second = 10000
     config.rc_message_burst_count = 10000
+    config.rc_registration_request_burst_count = 3.0
+    config.rc_registration_requests_per_second = 0.17
     config.saml2_enabled = False
     config.public_baseurl = None
     config.default_identity_server = None