diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d7ae148214..5f4c6703db 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Awaitable, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -30,6 +30,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -42,7 +45,7 @@ class LoginRestServlet(RestServlet):
JWT_TYPE_DEPRECATED = "m.login.jwt"
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -105,22 +108,27 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
async def on_POST(self, request: SynapseRequest):
- self._address_ratelimiter.ratelimit(request.getClientIP())
-
login_submission = parse_json_object_from_request(request)
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
appservice = self.auth.get_appservice_by_req(request)
+
+ if appservice.is_rate_limited():
+ self._address_ratelimiter.ratelimit(request.getClientIP())
+
result = await self._do_appservice_login(login_submission, appservice)
elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_token_login(login_submission)
else:
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_other_login(login_submission)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -159,7 +167,9 @@ class LoginRestServlet(RestServlet):
if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
- return await self._complete_login(qualified_user_id, login_submission)
+ return await self._complete_login(
+ qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+ )
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
@@ -194,6 +204,7 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
+ ratelimit: bool = True,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -208,6 +219,7 @@ class LoginRestServlet(RestServlet):
callback: Callback function to run after login.
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
+ ratelimit: Whether to ratelimit the login request.
Returns:
result: Dictionary of account information after successful login.
@@ -216,7 +228,8 @@ class LoginRestServlet(RestServlet):
# Before we actually log them in we check if they've already logged in
# too often. This happens here rather than before as we don't
# necessarily know the user before now.
- self._account_ratelimiter.ratelimit(user_id.lower())
+ if ratelimit:
+ self._account_ratelimiter.ratelimit(user_id.lower())
if create_non_existent_users:
canonical_uid = await self.auth_handler.check_user_exists(user_id)
|