summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorBrendan Abolivier <contact@brendanabolivier.com>2019-03-15 17:46:16 +0000
committerGitHub <noreply@github.com>2019-03-15 17:46:16 +0000
commit899e523d6d92dfbc17dce81eb36f63053e447a97 (patch)
tree5a8e2a7b2638cdc06a6dd4c8736c828c25ba47b9 /tests
parentMerge pull request #4855 from matrix-org/rav/refactor_transaction_queue (diff)
downloadsynapse-899e523d6d92dfbc17dce81eb36f63053e447a97.tar.xz
Add ratelimiting on login (#4821)
Add two ratelimiters on login (per-IP address and per-userID).
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/v1/test_login.py118
-rw-r--r--tests/rest/client/v2_alpha/test_register.py6
-rw-r--r--tests/utils.py8
3 files changed, 128 insertions, 4 deletions
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
new file mode 100644
index 0000000000..4035f76cca
--- /dev/null
+++ b/tests/rest/client/v1/test_login.py
@@ -0,0 +1,118 @@
+import json
+
+from synapse.rest.client.v1 import admin, login
+
+from tests import unittest
+
+LOGIN_URL = b"/_matrix/client/r0/login"
+
+
+class LoginRestServletTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor, clock):
+
+        self.hs = self.setup_test_homeserver()
+        self.hs.config.enable_registration = True
+        self.hs.config.registrations_require_3pid = []
+        self.hs.config.auto_join_rooms = []
+        self.hs.config.enable_registration_captcha = False
+
+        return self.hs
+
+    def test_POST_ratelimiting_per_address(self):
+        self.hs.config.rc_login_address.burst_count = 5
+        self.hs.config.rc_login_address.per_second = 0.17
+
+        # Create different users so we're sure not to be bothered by the per-user
+        # ratelimiter.
+        for i in range(0, 6):
+            self.register_user("kermit" + str(i), "monkey")
+
+        for i in range(0, 6):
+            params = {
+                "type": "m.login.password",
+                "identifier": {
+                    "type": "m.id.user",
+                    "user": "kermit" + str(i),
+                },
+                "password": "monkey",
+            }
+            request_data = json.dumps(params)
+            request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+            self.render(request)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
+        # than 1min.
+        self.assertTrue(retry_after_ms < 6000)
+
+        self.reactor.advance(retry_after_ms / 1000.)
+
+        params = {
+            "type": "m.login.password",
+            "identifier": {
+                "type": "m.id.user",
+                "user": "kermit" + str(i),
+            },
+            "password": "monkey",
+        }
+        request_data = json.dumps(params)
+        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        self.render(request)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
+
+    def test_POST_ratelimiting_per_account(self):
+        self.hs.config.rc_login_account.burst_count = 5
+        self.hs.config.rc_login_account.per_second = 0.17
+
+        self.register_user("kermit", "monkey")
+
+        for i in range(0, 6):
+            params = {
+                "type": "m.login.password",
+                "identifier": {
+                    "type": "m.id.user",
+                    "user": "kermit",
+                },
+                "password": "monkey",
+            }
+            request_data = json.dumps(params)
+            request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
+            self.render(request)
+
+            if i == 5:
+                self.assertEquals(channel.result["code"], b"429", channel.result)
+                retry_after_ms = int(channel.json_body["retry_after_ms"])
+            else:
+                self.assertEquals(channel.result["code"], b"200", channel.result)
+
+        # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
+        # than 1min.
+        self.assertTrue(retry_after_ms < 6000)
+
+        self.reactor.advance(retry_after_ms / 1000.)
+
+        params = {
+            "type": "m.login.password",
+            "identifier": {
+                "type": "m.id.user",
+                "user": "kermit",
+            },
+            "password": "monkey",
+        }
+        request_data = json.dumps(params)
+        request, channel = self.make_request(b"POST", LOGIN_URL, params)
+        self.render(request)
+
+        self.assertEquals(channel.result["code"], b"200", channel.result)
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 3600434858..8fb525d3bf 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -132,7 +132,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.json_body["error"], "Guest access is disabled")
 
     def test_POST_ratelimiting_guest(self):
-        self.hs.config.rc_registration_request_burst_count = 5
+        self.hs.config.rc_registration.burst_count = 5
+        self.hs.config.rc_registration.per_second = 0.17
 
         for i in range(0, 6):
             url = self.url + b"?kind=guest"
@@ -153,7 +154,8 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         self.assertEquals(channel.result["code"], b"200", channel.result)
 
     def test_POST_ratelimiting(self):
-        self.hs.config.rc_registration_request_burst_count = 5
+        self.hs.config.rc_registration.burst_count = 5
+        self.hs.config.rc_registration.per_second = 0.17
 
         for i in range(0, 6):
             params = {
diff --git a/tests/utils.py b/tests/utils.py
index 03b5a05b22..a412736492 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -151,8 +151,12 @@ def default_config(name):
     config.admin_contact = None
     config.rc_messages_per_second = 10000
     config.rc_message_burst_count = 10000
-    config.rc_registration_request_burst_count = 3.0
-    config.rc_registration_requests_per_second = 0.17
+    config.rc_registration.per_second = 10000
+    config.rc_registration.burst_count = 10000
+    config.rc_login_address.per_second = 10000
+    config.rc_login_address.burst_count = 10000
+    config.rc_login_account.per_second = 10000
+    config.rc_login_account.burst_count = 10000
     config.saml2_enabled = False
     config.public_baseurl = None
     config.default_identity_server = None