summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/v1/test_events.py4
-rw-r--r--tests/rest/client/v1/test_rooms.py6
-rw-r--r--tests/rest/client/v1/test_typing.py4
-rw-r--r--tests/rest/client/v2_alpha/test_register.py48
4 files changed, 55 insertions, 7 deletions
diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py
index 483bebc832..36d8547275 100644
--- a/tests/rest/client/v1/test_events.py
+++ b/tests/rest/client/v1/test_events.py
@@ -40,10 +40,10 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
         config.auto_join_rooms = []
 
         hs = self.setup_test_homeserver(
-            config=config, ratelimiter=NonCallableMock(spec_set=["send_message"])
+            config=config, ratelimiter=NonCallableMock(spec_set=["can_do_action"])
         )
         self.ratelimiter = hs.get_ratelimiter()
-        self.ratelimiter.send_message.return_value = (True, 0)
+        self.ratelimiter.can_do_action.return_value = (True, 0)
 
         hs.get_handlers().federation_handler = Mock()
 
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index a824be9a62..015c144248 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -41,10 +41,10 @@ class RoomBase(unittest.HomeserverTestCase):
             "red",
             http_client=None,
             federation_client=Mock(),
-            ratelimiter=NonCallableMock(spec_set=["send_message"]),
+            ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
         )
         self.ratelimiter = self.hs.get_ratelimiter()
-        self.ratelimiter.send_message.return_value = (True, 0)
+        self.ratelimiter.can_do_action.return_value = (True, 0)
 
         self.hs.get_federation_handler = Mock(return_value=Mock())
 
@@ -96,7 +96,7 @@ class RoomPermissionsTestCase(RoomBase):
         # auth as user_id now
         self.helper.auth_user_id = self.user_id
 
-    def test_send_message(self):
+    def test_can_do_action(self):
         msg_content = b'{"msgtype":"m.text","body":"hello"}'
 
         seq = iter(range(100))
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 0ad814c5e5..30fb77bac8 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -42,13 +42,13 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
             "red",
             http_client=None,
             federation_client=Mock(),
-            ratelimiter=NonCallableMock(spec_set=["send_message"]),
+            ratelimiter=NonCallableMock(spec_set=["can_do_action"]),
         )
 
         self.event_source = hs.get_event_sources().sources["typing"]
 
         self.ratelimiter = hs.get_ratelimiter()
-        self.ratelimiter.send_message.return_value = (True, 0)
+        self.ratelimiter.can_do_action.return_value = (True, 0)
 
         hs.get_handlers().federation_handler = Mock()
 
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 906b348d3e..3600434858 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -130,3 +130,51 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(channel.result["code"], b"403", channel.result)
         self.assertEquals(channel.json_body["error"], "Guest access is disabled")
+
+    def test_POST_ratelimiting_guest(self):
+        self.hs.config.rc_registration_request_burst_count = 5
+
+        for i in range(0, 6):
+            url = self.url + b"?kind=guest"
+            request, channel = self.make_request(b"POST", url, b"{}")
+            self.render(request)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(retry_after_ms / 1000.)
+
+        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        self.render(request)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    def test_POST_ratelimiting(self):
+        self.hs.config.rc_registration_request_burst_count = 5
+
+        for i in range(0, 6):
+            params = {
+                "username": "kermit" + str(i),
+                "password": "monkey",
+                "device_id": "frogfone",
+                "auth": {"type": LoginType.DUMMY},
+            }
+            request_data = json.dumps(params)
+            request, channel = self.make_request(b"POST", self.url, request_data)
+            self.render(request)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        self.reactor.advance(retry_after_ms / 1000.)
+
+        request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
+        self.render(request)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)