diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index e0feebea94..b67c1702ca 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -20,9 +20,6 @@ from http import HTTPStatus
from typing import TYPE_CHECKING
from urllib.parse import urlparse
-if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
-
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
@@ -31,6 +28,7 @@ from synapse.api.errors import (
ThreepidValidationError,
)
from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
@@ -46,13 +44,17 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+
logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/email/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.datastore = hs.get_datastore()
@@ -101,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
@@ -189,11 +193,7 @@ class PasswordRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
try:
params, session_id = await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
+ requester, request, body, "modify your account password",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth, but
@@ -204,7 +204,9 @@ class PasswordRestServlet(RestServlet):
if new_password:
password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
- e.session_id, "password_hash", password_hash
+ e.session_id,
+ UIAuthSessionDataConstants.PASSWORD_HASH,
+ password_hash,
)
raise
user_id = requester.user.to_string()
@@ -215,7 +217,6 @@ class PasswordRestServlet(RestServlet):
[[LoginType.EMAIL_IDENTITY]],
request,
body,
- self.hs.get_ip_from_request(request),
"modify your account password",
)
except InteractiveAuthIncompleteError as e:
@@ -227,7 +228,9 @@ class PasswordRestServlet(RestServlet):
if new_password:
password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
- e.session_id, "password_hash", password_hash
+ e.session_id,
+ UIAuthSessionDataConstants.PASSWORD_HASH,
+ password_hash,
)
raise
@@ -254,14 +257,18 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- # If we have a password in this request, prefer it. Otherwise, there
- # must be a password hash from an earlier request.
+ # If we have a password in this request, prefer it. Otherwise, use the
+ # password hash from an earlier request.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
- else:
+ elif session_id is not None:
password_hash = await self.auth_handler.get_session_data(
- session_id, "password_hash", None
+ session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
)
+ else:
+ # UI validation was skipped, but the request did not include a new
+ # password.
+ password_hash = None
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
@@ -300,19 +307,18 @@ class DeactivateAccountRestServlet(RestServlet):
# allow ASes to deactivate their own users
if requester.app_service:
await self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(), erase
+ requester.user.to_string(), erase, requester
)
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "deactivate your account",
+ requester, request, body, "deactivate your account",
)
result = await self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(), erase, id_server=body.get("id_server")
+ requester.user.to_string(),
+ erase,
+ requester,
+ id_server=body.get("id_server"),
)
if result:
id_server_unbind_result = "success"
@@ -375,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -426,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__()
self.store = self.hs.get_datastore()
@@ -454,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -691,11 +703,7 @@ class ThreepidAddRestServlet(RestServlet):
assert_valid_client_secret(client_secret)
await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "add a third-party identifier to your account",
+ requester, request, body, "add a third-party identifier to your account",
)
validation_session = await self.identity_handler.validate_threepid_session(
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 87a5b1b86b..3f28c0bc3e 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
- self._is_worker = hs.config.worker_app is not None
+ self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, account_data_type):
- if self._is_worker:
- raise Exception("Cannot handle PUT /account_data on worker")
-
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
body = parse_json_object_from_request(request)
- max_id = await self.store.add_account_data_for_user(
- user_id, account_data_type, body
- )
-
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+ await self.handler.add_account_data_for_user(user_id, account_data_type, body)
return 200, {}
@@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
- self._is_worker = hs.config.worker_app is not None
+ self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, account_data_type):
- if self._is_worker:
- raise Exception("Cannot handle PUT /account_data on worker")
-
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
" Use /rooms/!roomId:server.name/read_markers",
)
- max_id = await self.store.add_account_data_to_room(
+ await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
return 200, {}
async def on_GET(self, request, user_id, room_id, account_data_type):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index fab077747f..75ece1c911 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
@@ -23,6 +24,9 @@ from synapse.http.servlet import RestServlet, parse_string
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -35,28 +39,12 @@ class AuthRestServlet(RestServlet):
PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
-
- # SSO configuration.
- self._cas_enabled = hs.config.cas_enabled
- if self._cas_enabled:
- self._cas_handler = hs.get_cas_handler()
- self._cas_server_url = hs.config.cas_server_url
- self._cas_service_url = hs.config.cas_service_url
- self._saml_enabled = hs.config.saml2_enabled
- if self._saml_enabled:
- self._saml_handler = hs.get_saml_handler()
- self._oidc_enabled = hs.config.oidc_enabled
- if self._oidc_enabled:
- self._oidc_handler = hs.get_oidc_handler()
- self._cas_server_url = hs.config.cas_server_url
- self._cas_service_url = hs.config.cas_service_url
-
self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template
self.success_template = hs.config.fallback_success_template
@@ -85,32 +73,7 @@ class AuthRestServlet(RestServlet):
elif stagetype == LoginType.SSO:
# Display a confirmation page which prompts the user to
# re-authenticate with their SSO provider.
- if self._cas_enabled:
- # Generate a request to CAS that redirects back to an endpoint
- # to verify the successful authentication.
- sso_redirect_url = self._cas_handler.get_redirect_url(
- {"session": session},
- )
-
- elif self._saml_enabled:
- # Some SAML identity providers (e.g. Google) require a
- # RelayState parameter on requests. It is not necessary here, so
- # pass in a dummy redirect URL (which will never get used).
- client_redirect_url = b"unused"
- sso_redirect_url = self._saml_handler.handle_redirect_request(
- client_redirect_url, session
- )
-
- elif self._oidc_enabled:
- client_redirect_url = b""
- sso_redirect_url = await self._oidc_handler.handle_redirect_request(
- request, client_redirect_url, session
- )
-
- else:
- raise SynapseError(400, "Homeserver not configured for SSO.")
-
- html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+ html = await self.auth_handler.start_sso_ui_auth(request, session)
else:
raise SynapseError(404, "Unknown auth stage type")
@@ -134,7 +97,7 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session}
success = await self.auth_handler.add_oob_auth(
- LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
+ LoginType.RECAPTCHA, authdict, request.getClientIP()
)
if success:
@@ -150,7 +113,7 @@ class AuthRestServlet(RestServlet):
authdict = {"session": session}
success = await self.auth_handler.add_oob_auth(
- LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
+ LoginType.TERMS, authdict, request.getClientIP()
)
if success:
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index af117cb27c..314e01dfe4 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "remove device(s) from your account",
+ requester, request, body, "remove device(s) from your account",
)
await self.device_handler.delete_devices(
@@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet):
raise
await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "remove a device from your account",
+ requester, request, body, "remove a device from your account",
)
await self.device_handler.delete_device(requester.user.to_string(), device_id)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 75215a3779..28b55f27ad 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from functools import wraps
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -25,6 +26,22 @@ from ._base import client_patterns
logger = logging.getLogger(__name__)
+def _validate_group_id(f):
+ """Wrapper to validate the form of the group ID.
+
+ Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
+ """
+
+ @wraps(f)
+ def wrapper(self, request, group_id, *args, **kwargs):
+ if not GroupID.is_valid(group_id):
+ raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
+
+ return f(self, request, group_id, *args, **kwargs)
+
+ return wrapper
+
+
class GroupServlet(RestServlet):
"""Get the group profile
"""
@@ -37,6 +54,7 @@ class GroupServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -47,6 +65,7 @@ class GroupServlet(RestServlet):
return 200, group_description
+ @_validate_group_id
async def on_POST(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -71,6 +90,7 @@ class GroupSummaryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -102,6 +122,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, category_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -117,6 +138,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, category_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -142,6 +164,7 @@ class GroupCategoryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -152,6 +175,7 @@ class GroupCategoryServlet(RestServlet):
return 200, category
+ @_validate_group_id
async def on_PUT(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -163,6 +187,7 @@ class GroupCategoryServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -186,6 +211,7 @@ class GroupCategoriesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -209,6 +235,7 @@ class GroupRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -219,6 +246,7 @@ class GroupRoleServlet(RestServlet):
return 200, category
+ @_validate_group_id
async def on_PUT(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -230,6 +258,7 @@ class GroupRoleServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -253,6 +282,7 @@ class GroupRolesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -284,6 +314,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, role_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -299,6 +330,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, role_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -322,13 +354,11 @@ class GroupRoomServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- if not GroupID.is_valid(group_id):
- raise SynapseError(400, "%s was not legal group ID" % (group_id,))
-
result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id
)
@@ -348,6 +378,7 @@ class GroupUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -371,6 +402,7 @@ class GroupInvitedUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -393,6 +425,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -449,6 +482,7 @@ class GroupAdminRoomsServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -460,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet):
return 200, result
+ @_validate_group_id
async def on_DELETE(self, request, group_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -486,6 +521,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, room_id, config_key):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -514,6 +550,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
+ @_validate_group_id
async def on_PUT(self, request, group_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -541,6 +578,7 @@ class GroupAdminUsersKickServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -590,6 +628,7 @@ class GroupSelfLeaveServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -614,6 +653,7 @@ class GroupSelfJoinServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -638,6 +678,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -662,6 +703,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index b91996c738..a6134ead8a 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "add a device signing key to your account",
+ requester, request, body, "add a device signing key to your account",
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 5374d2c1b6..f0675abd32 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
@@ -125,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
@@ -204,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn
)
@@ -353,7 +360,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
- ip = self.hs.get_ip_from_request(request)
+ ip = request.getClientIP()
with self.ratelimiter.ratelimit(ip) as wait_deferred:
await wait_deferred
@@ -451,7 +458,7 @@ class RegisterRestServlet(RestServlet):
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
- raise SynapseError(403, "Registration has been disabled")
+ raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
# For regular registration, convert the provided username to lowercase
# before attempting to register it. This should mean that people who try
@@ -494,11 +501,11 @@ class RegisterRestServlet(RestServlet):
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = await self.auth_handler.get_session_data(
- session_id, "registered_user_id", None
+ session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
)
# Extract the previously-hashed password from the session.
password_hash = await self.auth_handler.get_session_data(
- session_id, "password_hash", None
+ session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
)
# Ensure that the username is valid.
@@ -513,11 +520,7 @@ class RegisterRestServlet(RestServlet):
# not this will raise a user-interactive auth error.
try:
auth_result, params, session_id = await self.auth_handler.check_ui_auth(
- self._registration_flows,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "register a new account",
+ self._registration_flows, request, body, "register a new account",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth.
@@ -532,7 +535,9 @@ class RegisterRestServlet(RestServlet):
if not password_hash and password:
password_hash = await self.auth_handler.hash(password)
await self.auth_handler.set_session_data(
- e.session_id, "password_hash", password_hash
+ e.session_id,
+ UIAuthSessionDataConstants.PASSWORD_HASH,
+ password_hash,
)
raise
@@ -635,7 +640,9 @@ class RegisterRestServlet(RestServlet):
# Remember that the user account has been registered (and the user
# ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
- session_id, "registered_user_id", registered_user_id
+ session_id,
+ UIAuthSessionDataConstants.REGISTERED_USER_ID,
+ registered_user_id,
)
registered = True
@@ -657,9 +664,13 @@ class RegisterRestServlet(RestServlet):
user_id = await self.registration_handler.appservice_register(
username, as_token
)
- return await self._create_registration_details(user_id, body)
+ return await self._create_registration_details(
+ user_id, body, is_appservice_ghost=True,
+ )
- async def _create_registration_details(self, user_id, params):
+ async def _create_registration_details(
+ self, user_id, params, is_appservice_ghost=False
+ ):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
@@ -676,7 +687,11 @@ class RegisterRestServlet(RestServlet):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest=False
+ user_id,
+ device_id,
+ initial_display_name,
+ is_guest=False,
+ is_appservice_ghost=is_appservice_ghost,
)
result.update({"access_token": access_token, "device_id": device_id})
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index bc4f43639a..a3dee14ed4 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -17,7 +17,7 @@ import logging
from typing import Tuple
from synapse.http import servlet
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.logging.opentracing import set_tag, trace
from synapse.rest.client.transactions import HttpTransactionCache
@@ -54,6 +54,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
+ assert_params_in_dict(content, ("messages",))
sender_user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index bf3a79db44..a97cd66c52 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -58,8 +58,7 @@ class TagServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.auth = hs.get_auth()
- self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
+ self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, tag):
requester = await self.auth.get_user_by_req(request)
@@ -68,9 +67,7 @@ class TagServlet(RestServlet):
body = parse_json_object_from_request(request)
- max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
-
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+ await self.handler.add_tag_to_room(user_id, room_id, tag, body)
return 200, {}
@@ -79,9 +76,7 @@ class TagServlet(RestServlet):
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
- max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
-
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+ await self.handler.remove_tag_from_room(user_id, room_id, tag)
return 200, {}
|