summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/9238.feature1
-rw-r--r--docs/sample_config.yaml6
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/config/ratelimiting.py13
-rw-r--r--synapse/handlers/identity.py28
-rw-r--r--synapse/rest/client/v2_alpha/account.py12
-rw-r--r--synapse/rest/client/v2_alpha/register.py6
-rw-r--r--tests/rest/client/v2_alpha/test_account.py90
-rw-r--r--tests/server.py9
-rw-r--r--tests/unittest.py5
-rw-r--r--tests/utils.py1
11 files changed, 159 insertions, 14 deletions
diff --git a/changelog.d/9238.feature b/changelog.d/9238.feature
new file mode 100644
index 0000000000..143a3e14f5
--- /dev/null
+++ b/changelog.d/9238.feature
@@ -0,0 +1 @@
+Add ratelimited to 3PID `/requestToken` API.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index c2ccd68f3a..e5b6268087 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -824,6 +824,7 @@ log_config: "CONFDIR/SERVERNAME.log.config"
 #     users are joining rooms the server is already in (this is cheap) vs
 #     "remote" for when users are trying to join rooms not on the server (which
 #     can be more expensive)
+#   - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
 #
 # The defaults are as shown below.
 #
@@ -857,7 +858,10 @@ log_config: "CONFDIR/SERVERNAME.log.config"
 #  remote:
 #    per_second: 0.01
 #    burst_count: 3
-
+#
+#rc_3pid_validation:
+#  per_second: 0.003
+#  burst_count: 5
 
 # Ratelimiting settings for incoming federation
 #
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 7ed07a801d..70025b5d60 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -54,7 +54,7 @@ class RootConfig:
     tls: tls.TlsConfig
     database: database.DatabaseConfig
     logging: logger.LoggingConfig
-    ratelimit: ratelimiting.RatelimitConfig
+    ratelimiting: ratelimiting.RatelimitConfig
     media: repository.ContentRepositoryConfig
     captcha: captcha.CaptchaConfig
     voip: voip.VoipConfig
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 14b8836197..76f382527d 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -24,7 +24,7 @@ class RateLimitConfig:
         defaults={"per_second": 0.17, "burst_count": 3.0},
     ):
         self.per_second = config.get("per_second", defaults["per_second"])
-        self.burst_count = config.get("burst_count", defaults["burst_count"])
+        self.burst_count = int(config.get("burst_count", defaults["burst_count"]))
 
 
 class FederationRateLimitConfig:
@@ -102,6 +102,11 @@ class RatelimitConfig(Config):
             defaults={"per_second": 0.01, "burst_count": 3},
         )
 
+        self.rc_3pid_validation = RateLimitConfig(
+            config.get("rc_3pid_validation") or {},
+            defaults={"per_second": 0.003, "burst_count": 5},
+        )
+
     def generate_config_section(self, **kwargs):
         return """\
         ## Ratelimiting ##
@@ -131,6 +136,7 @@ class RatelimitConfig(Config):
         #     users are joining rooms the server is already in (this is cheap) vs
         #     "remote" for when users are trying to join rooms not on the server (which
         #     can be more expensive)
+        #   - one for ratelimiting how often a user or IP can attempt to validate a 3PID.
         #
         # The defaults are as shown below.
         #
@@ -164,7 +170,10 @@ class RatelimitConfig(Config):
         #  remote:
         #    per_second: 0.01
         #    burst_count: 3
-
+        #
+        #rc_3pid_validation:
+        #  per_second: 0.003
+        #  burst_count: 5
 
         # Ratelimiting settings for incoming federation
         #
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index f61844d688..4f7137539b 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -27,9 +27,11 @@ from synapse.api.errors import (
     HttpResponseException,
     SynapseError,
 )
+from synapse.api.ratelimiting import Ratelimiter
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
+from synapse.http.site import SynapseRequest
 from synapse.types import JsonDict, Requester
 from synapse.util import json_decoder
 from synapse.util.hash import sha256_and_url_safe_base64
@@ -57,6 +59,32 @@ class IdentityHandler(BaseHandler):
 
         self._web_client_location = hs.config.invite_client_location
 
+        # Ratelimiters for `/requestToken` endpoints.
+        self._3pid_validation_ratelimiter_ip = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+        )
+        self._3pid_validation_ratelimiter_address = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second,
+            burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
+        )
+
+    def ratelimit_request_token_requests(
+        self, request: SynapseRequest, medium: str, address: str,
+    ):
+        """Used to ratelimit requests to `/requestToken` by IP and address.
+
+        Args:
+            request: The associated request
+            medium: The type of threepid, e.g. "msisdn" or "email"
+            address: The actual threepid ID, e.g. the phone number or email address
+        """
+
+        self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP()))
+        self._3pid_validation_ratelimiter_address.ratelimit((medium, address))
+
     async def threepid_from_creds(
         self, id_server: str, creds: Dict[str, str]
     ) -> Optional[JsonDict]:
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 65e68d641b..a84a2fb385 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
 class EmailPasswordRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/password/email/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.datastore = hs.get_datastore()
@@ -103,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to
         # an email address which is controlled by the attacker but which, after
@@ -379,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
@@ -430,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
 class MsisdnThreepidRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         super().__init__()
         self.store = self.hs.get_datastore()
@@ -458,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(
+            request, "msisdn", msisdn
+        )
+
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index b093183e79..10e1891174 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -126,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "email", email
         )
@@ -205,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(
+            request, "msisdn", msisdn
+        )
+
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "msisdn", msisdn
         )
diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index cb87b80e33..177dc476da 100644
--- a/tests/rest/client/v2_alpha/test_account.py
+++ b/tests/rest/client/v2_alpha/test_account.py
@@ -24,7 +24,7 @@ import pkg_resources
 
 import synapse.rest.admin
 from synapse.api.constants import LoginType, Membership
-from synapse.api.errors import Codes
+from synapse.api.errors import Codes, HttpResponseException
 from synapse.rest.client.v1 import login, room
 from synapse.rest.client.v2_alpha import account, register
 from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource
@@ -112,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
         # Assert we can't log in with the old password
         self.attempt_wrong_password_login("kermit", old_password)
 
+    @override_config({"rc_3pid_validation": {"burst_count": 3}})
+    def test_ratelimit_by_email(self):
+        """Test that we ratelimit /requestToken for the same email.
+        """
+        old_password = "monkey"
+        new_password = "kangeroo"
+
+        user_id = self.register_user("kermit", old_password)
+        self.login("kermit", old_password)
+
+        email = "test1@example.com"
+
+        # Add a threepid
+        self.get_success(
+            self.store.user_add_threepid(
+                user_id=user_id,
+                medium="email",
+                address=email,
+                validated_at=0,
+                added_at=0,
+            )
+        )
+
+        def reset(ip):
+            client_secret = "foobar"
+            session_id = self._request_token(email, client_secret, ip)
+
+            self.assertEquals(len(self.email_attempts), 1)
+            link = self._get_link_from_email()
+
+            self._validate_token(link)
+
+            self._reset_password(new_password, session_id, client_secret)
+
+            self.email_attempts.clear()
+
+        # We expect to be able to make three requests before getting rate
+        # limited.
+        #
+        # We change IPs to ensure that we're not being ratelimited due to the
+        # same IP
+        reset("127.0.0.1")
+        reset("127.0.0.2")
+        reset("127.0.0.3")
+
+        with self.assertRaises(HttpResponseException) as cm:
+            reset("127.0.0.4")
+
+        self.assertEqual(cm.exception.code, 429)
+
     def test_basic_password_reset_canonicalise_email(self):
         """Test basic password reset flow
         Request password reset with different spelling
@@ -239,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
 
         self.assertIsNotNone(session_id)
 
-    def _request_token(self, email, client_secret):
+    def _request_token(self, email, client_secret, ip="127.0.0.1"):
         channel = self.make_request(
             "POST",
             b"account/password/email/requestToken",
             {"client_secret": client_secret, "email": email, "send_attempt": 1},
+            client_ip=ip,
         )
-        self.assertEquals(200, channel.code, channel.result)
+
+        if channel.code != 200:
+            raise HttpResponseException(
+                channel.code, channel.result["reason"], channel.result["body"],
+            )
 
         return channel.json_body["sid"]
 
@@ -509,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
     def test_address_trim(self):
         self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar"))
 
+    @override_config({"rc_3pid_validation": {"burst_count": 3}})
+    def test_ratelimit_by_ip(self):
+        """Tests that adding emails is ratelimited by IP
+        """
+
+        # We expect to be able to set three emails before getting ratelimited.
+        self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar"))
+        self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar"))
+        self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar"))
+
+        with self.assertRaises(HttpResponseException) as cm:
+            self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar"))
+
+        self.assertEqual(cm.exception.code, 429)
+
     def test_add_email_if_disabled(self):
         """Test adding email to profile when doing so is disallowed
         """
@@ -777,7 +847,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             body["next_link"] = next_link
 
         channel = self.make_request("POST", b"account/3pid/email/requestToken", body,)
-        self.assertEquals(expect_code, channel.code, channel.result)
+
+        if channel.code != expect_code:
+            raise HttpResponseException(
+                channel.code, channel.result["reason"], channel.result["body"],
+            )
 
         return channel.json_body.get("sid")
 
@@ -823,10 +897,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
     def _add_email(self, request_email, expected_email):
         """Test adding an email to profile
         """
+        previous_email_attempts = len(self.email_attempts)
+
         client_secret = "foobar"
         session_id = self._request_token(request_email, client_secret)
 
-        self.assertEquals(len(self.email_attempts), 1)
+        self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1)
         link = self._get_link_from_email()
 
         self._validate_token(link)
@@ -855,4 +931,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
         self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
-        self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"])
+
+        threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
+        self.assertIn(expected_email, threepids)
diff --git a/tests/server.py b/tests/server.py
index 5a85d5fe7f..6419c445ec 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -47,6 +47,7 @@ class FakeChannel:
     site = attr.ib(type=Site)
     _reactor = attr.ib()
     result = attr.ib(type=dict, default=attr.Factory(dict))
+    _ip = attr.ib(type=str, default="127.0.0.1")
     _producer = None
 
     @property
@@ -120,7 +121,7 @@ class FakeChannel:
     def getPeer(self):
         # We give an address so that getClientIP returns a non null entry,
         # causing us to record the MAU
-        return address.IPv4Address("TCP", "127.0.0.1", 3423)
+        return address.IPv4Address("TCP", self._ip, 3423)
 
     def getHost(self):
         return None
@@ -196,6 +197,7 @@ def make_request(
     custom_headers: Optional[
         Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
     ] = None,
+    client_ip: str = "127.0.0.1",
 ) -> FakeChannel:
     """
     Make a web request using the given method, path and content, and render it
@@ -223,6 +225,9 @@ def make_request(
              will pump the reactor until the the renderer tells the channel the request
              is finished.
 
+        client_ip: The IP to use as the requesting IP. Useful for testing
+            ratelimiting.
+
     Returns:
         channel
     """
@@ -250,7 +255,7 @@ def make_request(
     if isinstance(content, str):
         content = content.encode("utf8")
 
-    channel = FakeChannel(site, reactor)
+    channel = FakeChannel(site, reactor, ip=client_ip)
 
     req = request(channel)
     req.content = BytesIO(content)
diff --git a/tests/unittest.py b/tests/unittest.py
index bbd295687c..767d5d6077 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -386,6 +386,7 @@ class HomeserverTestCase(TestCase):
         custom_headers: Optional[
             Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
         ] = None,
+        client_ip: str = "127.0.0.1",
     ) -> FakeChannel:
         """
         Create a SynapseRequest at the path using the method and containing the
@@ -410,6 +411,9 @@ class HomeserverTestCase(TestCase):
 
             custom_headers: (name, value) pairs to add as request headers
 
+            client_ip: The IP to use as the requesting IP. Useful for testing
+                ratelimiting.
+
         Returns:
             The FakeChannel object which stores the result of the request.
         """
@@ -426,6 +430,7 @@ class HomeserverTestCase(TestCase):
             content_is_form,
             await_result,
             custom_headers,
+            client_ip,
         )
 
     def setup_test_homeserver(self, *args, **kwargs):
diff --git a/tests/utils.py b/tests/utils.py
index 022223cf24..68033d7535 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -157,6 +157,7 @@ def default_config(name, parse=False):
             "local": {"per_second": 10000, "burst_count": 10000},
             "remote": {"per_second": 10000, "burst_count": 10000},
         },
+        "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
         "saml2_enabled": False,
         "default_identity_server": None,
         "key_refresh_interval": 24 * 60 * 60 * 1000,