summary refs log tree commit diff
path: root/synapse/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client')
-rw-r--r--synapse/rest/client/v1/login.py169
-rw-r--r--synapse/rest/client/v1/pusher.py15
-rw-r--r--synapse/rest/client/v1/room.py37
-rw-r--r--synapse/rest/client/v2_alpha/account.py66
-rw-r--r--synapse/rest/client/v2_alpha/account_data.py22
-rw-r--r--synapse/rest/client/v2_alpha/auth.py53
-rw-r--r--synapse/rest/client/v2_alpha/devices.py12
-rw-r--r--synapse/rest/client/v2_alpha/groups.py48
-rw-r--r--synapse/rest/client/v2_alpha/keys.py6
-rw-r--r--synapse/rest/client/v2_alpha/register.py43
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py3
-rw-r--r--synapse/rest/client/v2_alpha/tags.py11
12 files changed, 248 insertions, 237 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d7ae148214..0fb9419e58 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,12 +14,13 @@
 # limitations under the License.
 
 import logging
-from typing import Awaitable, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.appservice import ApplicationService
-from synapse.http.server import finish_request
+from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
     parse_json_object_from_request,
@@ -30,6 +31,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.well_known import WellKnownBuilder
 from synapse.types import JsonDict, UserID
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -42,7 +46,7 @@ class LoginRestServlet(RestServlet):
     JWT_TYPE_DEPRECATED = "m.login.jwt"
     APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
 
@@ -57,11 +61,14 @@ class LoginRestServlet(RestServlet):
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
+        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
 
         self.auth = hs.get_auth()
 
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
+        self._sso_handler = hs.get_sso_handler()
+
         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
             clock=hs.get_clock(),
@@ -86,8 +93,17 @@ class LoginRestServlet(RestServlet):
             flows.append({"type": LoginRestServlet.CAS_TYPE})
 
         if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
-            flows.append({"type": LoginRestServlet.SSO_TYPE})
-            # While its valid for us to advertise this login type generally,
+            sso_flow = {"type": LoginRestServlet.SSO_TYPE}  # type: JsonDict
+
+            if self._msc2858_enabled:
+                sso_flow["org.matrix.msc2858.identity_providers"] = [
+                    _get_auth_flow_dict_for_idp(idp)
+                    for idp in self._sso_handler.get_identity_providers().values()
+                ]
+
+            flows.append(sso_flow)
+
+            # While it's valid for us to advertise this login type generally,
             # synapse currently only gives out these tokens as part of the
             # SSO login flow.
             # Generally we don't want to advertise login flows that clients
@@ -105,22 +121,27 @@ class LoginRestServlet(RestServlet):
         return 200, {"flows": flows}
 
     async def on_POST(self, request: SynapseRequest):
-        self._address_ratelimiter.ratelimit(request.getClientIP())
-
         login_submission = parse_json_object_from_request(request)
 
         try:
             if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
                 appservice = self.auth.get_appservice_by_req(request)
+
+                if appservice.is_rate_limited():
+                    self._address_ratelimiter.ratelimit(request.getClientIP())
+
                 result = await self._do_appservice_login(login_submission, appservice)
             elif self.jwt_enabled and (
                 login_submission["type"] == LoginRestServlet.JWT_TYPE
                 or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
             ):
+                self._address_ratelimiter.ratelimit(request.getClientIP())
                 result = await self._do_jwt_login(login_submission)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
+                self._address_ratelimiter.ratelimit(request.getClientIP())
                 result = await self._do_token_login(login_submission)
             else:
+                self._address_ratelimiter.ratelimit(request.getClientIP())
                 result = await self._do_other_login(login_submission)
         except KeyError:
             raise SynapseError(400, "Missing JSON keys.")
@@ -159,7 +180,9 @@ class LoginRestServlet(RestServlet):
         if not appservice.is_interested_in_user(qualified_user_id):
             raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
 
-        return await self._complete_login(qualified_user_id, login_submission)
+        return await self._complete_login(
+            qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+        )
 
     async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
         """Handle non-token/saml/jwt logins
@@ -194,6 +217,7 @@ class LoginRestServlet(RestServlet):
         login_submission: JsonDict,
         callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
         create_non_existent_users: bool = False,
+        ratelimit: bool = True,
     ) -> Dict[str, str]:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
@@ -208,6 +232,7 @@ class LoginRestServlet(RestServlet):
             callback: Callback function to run after login.
             create_non_existent_users: Whether to create the user if they don't
                 exist. Defaults to False.
+            ratelimit: Whether to ratelimit the login request.
 
         Returns:
             result: Dictionary of account information after successful login.
@@ -216,7 +241,8 @@ class LoginRestServlet(RestServlet):
         # Before we actually log them in we check if they've already logged in
         # too often. This happens here rather than before as we don't
         # necessarily know the user before now.
-        self._account_ratelimiter.ratelimit(user_id.lower())
+        if ratelimit:
+            self._account_ratelimiter.ratelimit(user_id.lower())
 
         if create_non_existent_users:
             canonical_uid = await self.auth_handler.check_user_exists(user_id)
@@ -298,48 +324,63 @@ class LoginRestServlet(RestServlet):
         return result
 
 
-class BaseSSORedirectServlet(RestServlet):
-    """Common base class for /login/sso/redirect impls"""
-
-    PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+    """Return an entry for the login flow dict
+
+    Returns an entry suitable for inclusion in "identity_providers" in the
+    response to GET /_matrix/client/r0/login
+    """
+    e = {"id": idp.idp_id, "name": idp.idp_name}  # type: JsonDict
+    if idp.idp_icon:
+        e["icon"] = idp.idp_icon
+    if idp.idp_brand:
+        e["brand"] = idp.idp_brand
+    return e
+
+
+class SsoRedirectServlet(RestServlet):
+    PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
+
+    def __init__(self, hs: "HomeServer"):
+        # make sure that the relevant handlers are instantiated, so that they
+        # register themselves with the main SSOHandler.
+        if hs.config.cas_enabled:
+            hs.get_cas_handler()
+        if hs.config.saml2_enabled:
+            hs.get_saml_handler()
+        if hs.config.oidc_enabled:
+            hs.get_oidc_handler()
+        self._sso_handler = hs.get_sso_handler()
+        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+
+    def register(self, http_server: HttpServer) -> None:
+        super().register(http_server)
+        if self._msc2858_enabled:
+            # expose additional endpoint for MSC2858 support
+            http_server.register_paths(
+                "GET",
+                client_patterns(
+                    "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
+                    releases=(),
+                    unstable=True,
+                ),
+                self.on_GET,
+                self.__class__.__name__,
+            )
 
-    async def on_GET(self, request: SynapseRequest):
-        args = request.args
-        if b"redirectUrl" not in args:
-            return 400, "Redirect URL not specified for SSO auth"
-        client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = await self.get_sso_url(request, client_redirect_url)
+    async def on_GET(
+        self, request: SynapseRequest, idp_id: Optional[str] = None
+    ) -> None:
+        client_redirect_url = parse_string(
+            request, "redirectUrl", required=True, encoding=None
+        )
+        sso_url = await self._sso_handler.handle_redirect_request(
+            request, client_redirect_url, idp_id,
+        )
+        logger.info("Redirecting to %s", sso_url)
         request.redirect(sso_url)
         finish_request(request)
 
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        """Get the URL to redirect to, to perform SSO auth
-
-        Args:
-            request: The client request to redirect.
-            client_redirect_url: the URL that we should redirect the
-                client to when everything is done
-
-        Returns:
-            URL to redirect to
-        """
-        # to be implemented by subclasses
-        raise NotImplementedError()
-
-
-class CasRedirectServlet(BaseSSORedirectServlet):
-    def __init__(self, hs):
-        self._cas_handler = hs.get_cas_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._cas_handler.get_redirect_url(
-            {"redirectUrl": client_redirect_url}
-        ).encode("ascii")
-
 
 class CasTicketServlet(RestServlet):
     PATTERNS = client_patterns("/login/cas/ticket", v1=True)
@@ -366,40 +407,8 @@ class CasTicketServlet(RestServlet):
         )
 
 
-class SAMLRedirectServlet(BaseSSORedirectServlet):
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._saml_handler = hs.get_saml_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return self._saml_handler.handle_redirect_request(client_redirect_url)
-
-
-class OIDCRedirectServlet(BaseSSORedirectServlet):
-    """Implementation for /login/sso/redirect for the OIDC login flow."""
-
-    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
-    def __init__(self, hs):
-        self._oidc_handler = hs.get_oidc_handler()
-
-    async def get_sso_url(
-        self, request: SynapseRequest, client_redirect_url: bytes
-    ) -> bytes:
-        return await self._oidc_handler.handle_redirect_request(
-            request, client_redirect_url
-        )
-
-
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
+    SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas_enabled:
-        CasRedirectServlet(hs).register(http_server)
         CasTicketServlet(hs).register(http_server)
-    elif hs.config.saml2_enabled:
-        SAMLRedirectServlet(hs).register(http_server)
-    elif hs.config.oidc_enabled:
-        OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 8fe83f321a..89823fcc39 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
 
 logger = logging.getLogger(__name__)
 
-ALLOWED_KEYS = {
-    "app_display_name",
-    "app_id",
-    "data",
-    "device_display_name",
-    "kind",
-    "lang",
-    "profile_tag",
-    "pushkey",
-}
-
 
 class PushersRestServlet(RestServlet):
     PATTERNS = client_patterns("/pushers$", v1=True)
@@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet):
 
         pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
 
-        filtered_pushers = [
-            {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
-        ]
+        filtered_pushers = [p.as_dict() for p in pushers]
 
         return 200, {"pushers": filtered_pushers}
 
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 93c06afe27..f95627ee61 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -46,7 +46,7 @@ from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
 from synapse.util import json_decoder
-from synapse.util.stringutils import random_string
+from synapse.util.stringutils import parse_and_validate_server_name, random_string
 
 if TYPE_CHECKING:
     import synapse.server
@@ -347,8 +347,6 @@ class PublicRoomListRestServlet(TransactionRestServlet):
             # provided.
             if server:
                 raise e
-            else:
-                pass
 
         limit = parse_integer(request, "limit", 0)
         since_token = parse_string(request, "since", None)
@@ -359,6 +357,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server and server != self.hs.config.server_name:
+            # Ensure the server is valid.
+            try:
+                parse_and_validate_server_name(server)
+            except ValueError:
+                raise SynapseError(
+                    400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+                )
+
             try:
                 data = await handler.get_remote_public_room_list(
                     server, limit=limit, since_token=since_token
@@ -402,6 +408,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server and server != self.hs.config.server_name:
+            # Ensure the server is valid.
+            try:
+                parse_and_validate_server_name(server)
+            except ValueError:
+                raise SynapseError(
+                    400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+                )
+
             try:
                 data = await handler.get_remote_public_room_list(
                     server,
@@ -963,25 +977,28 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
         )
 
 
-def register_servlets(hs, http_server):
+def register_servlets(hs, http_server, is_worker=False):
     RoomStateEventRestServlet(hs).register(http_server)
-    RoomCreateRestServlet(hs).register(http_server)
     RoomMemberListRestServlet(hs).register(http_server)
     JoinedRoomMemberListRestServlet(hs).register(http_server)
     RoomMessageListRestServlet(hs).register(http_server)
     JoinRoomAliasServlet(hs).register(http_server)
-    RoomForgetRestServlet(hs).register(http_server)
     RoomMembershipRestServlet(hs).register(http_server)
     RoomSendEventRestServlet(hs).register(http_server)
     PublicRoomListRestServlet(hs).register(http_server)
     RoomStateRestServlet(hs).register(http_server)
     RoomRedactEventRestServlet(hs).register(http_server)
     RoomTypingRestServlet(hs).register(http_server)
-    SearchRestServlet(hs).register(http_server)
-    JoinedRoomsRestServlet(hs).register(http_server)
-    RoomEventServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
-    RoomAliasListServlet(hs).register(http_server)
+
+    # Some servlets only get registered for the main process.
+    if not is_worker:
+        RoomCreateRestServlet(hs).register(http_server)
+        RoomForgetRestServlet(hs).register(http_server)
+        SearchRestServlet(hs).register(http_server)
+        JoinedRoomsRestServlet(hs).register(http_server)
+        RoomEventServlet(hs).register(http_server)
+        RoomAliasListServlet(hs).register(http_server)
 
 
 def register_deprecated_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index e0feebea94..b67c1702ca 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -20,9 +20,6 @@ from http import HTTPStatus
 from typing import TYPE_CHECKING
 from urllib.parse import urlparse
 
-if TYPE_CHECKING:
-    from synapse.app.homeserver import HomeServer
-
 from synapse.api.constants import LoginType
 from synapse.api.errors import (
     Codes,
@@ -31,6 +28,7 @@ from synapse.api.errors import (
     ThreepidValidationError,
 )
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -46,13 +44,17 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
 
 from ._base import client_patterns, interactive_auth_handler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
 class EmailPasswordRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/password/email/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.datastore = hs.get_datastore()
@@ -101,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         # The email will be sent to the stored address.
         # This avoids a potential account hijack by requesting a password reset to
         # an email address which is controlled by the attacker but which, after
@@ -189,11 +193,7 @@ class PasswordRestServlet(RestServlet):
             requester = await self.auth.get_user_by_req(request)
             try:
                 params, session_id = await self.auth_handler.validate_user_via_ui_auth(
-                    requester,
-                    request,
-                    body,
-                    self.hs.get_ip_from_request(request),
-                    "modify your account password",
+                    requester, request, body, "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
                 # The user needs to provide more steps to complete auth, but
@@ -204,7 +204,9 @@ class PasswordRestServlet(RestServlet):
                 if new_password:
                     password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
-                        e.session_id, "password_hash", password_hash
+                        e.session_id,
+                        UIAuthSessionDataConstants.PASSWORD_HASH,
+                        password_hash,
                     )
                 raise
             user_id = requester.user.to_string()
@@ -215,7 +217,6 @@ class PasswordRestServlet(RestServlet):
                     [[LoginType.EMAIL_IDENTITY]],
                     request,
                     body,
-                    self.hs.get_ip_from_request(request),
                     "modify your account password",
                 )
             except InteractiveAuthIncompleteError as e:
@@ -227,7 +228,9 @@ class PasswordRestServlet(RestServlet):
                 if new_password:
                     password_hash = await self.auth_handler.hash(new_password)
                     await self.auth_handler.set_session_data(
-                        e.session_id, "password_hash", password_hash
+                        e.session_id,
+                        UIAuthSessionDataConstants.PASSWORD_HASH,
+                        password_hash,
                     )
                 raise
 
@@ -254,14 +257,18 @@ class PasswordRestServlet(RestServlet):
                 logger.error("Auth succeeded but no known type! %r", result.keys())
                 raise SynapseError(500, "", Codes.UNKNOWN)
 
-        # If we have a password in this request, prefer it. Otherwise, there
-        # must be a password hash from an earlier request.
+        # If we have a password in this request, prefer it. Otherwise, use the
+        # password hash from an earlier request.
         if new_password:
             password_hash = await self.auth_handler.hash(new_password)
-        else:
+        elif session_id is not None:
             password_hash = await self.auth_handler.get_session_data(
-                session_id, "password_hash", None
+                session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
             )
+        else:
+            # UI validation was skipped, but the request did not include a new
+            # password.
+            password_hash = None
         if not password_hash:
             raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
 
@@ -300,19 +307,18 @@ class DeactivateAccountRestServlet(RestServlet):
         # allow ASes to deactivate their own users
         if requester.app_service:
             await self._deactivate_account_handler.deactivate_account(
-                requester.user.to_string(), erase
+                requester.user.to_string(), erase, requester
             )
             return 200, {}
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "deactivate your account",
+            requester, request, body, "deactivate your account",
         )
         result = await self._deactivate_account_handler.deactivate_account(
-            requester.user.to_string(), erase, id_server=body.get("id_server")
+            requester.user.to_string(),
+            erase,
+            requester,
+            id_server=body.get("id_server"),
         )
         if result:
             id_server_unbind_result = "success"
@@ -375,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
@@ -426,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
 class MsisdnThreepidRequestTokenRestServlet(RestServlet):
     PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         super().__init__()
         self.store = self.hs.get_datastore()
@@ -454,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(
+            request, "msisdn", msisdn
+        )
+
         if next_link:
             # Raise if the provided next_link value isn't valid
             assert_valid_next_link(self.hs, next_link)
@@ -691,11 +703,7 @@ class ThreepidAddRestServlet(RestServlet):
         assert_valid_client_secret(client_secret)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a third-party identifier to your account",
+            requester, request, body, "add a third-party identifier to your account",
         )
 
         validation_session = await self.identity_handler.validate_threepid_session(
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 87a5b1b86b..3f28c0bc3e 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_account_data_for_user(
-            user_id, account_data_type, body
-        )
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_account_data_for_user(user_id, account_data_type, body)
 
         return 200, {}
 
@@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
-        self._is_worker = hs.config.worker_app is not None
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, account_data_type):
-        if self._is_worker:
-            raise Exception("Cannot handle PUT /account_data on worker")
-
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
                 " Use /rooms/!roomId:server.name/read_markers",
             )
 
-        max_id = await self.store.add_account_data_to_room(
+        await self.handler.add_account_data_to_room(
             user_id, room_id, account_data_type, body
         )
 
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
         return 200, {}
 
     async def on_GET(self, request, user_id, room_id, account_data_type):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index fab077747f..75ece1c911 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError
@@ -23,6 +24,9 @@ from synapse.http.servlet import RestServlet, parse_string
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -35,28 +39,12 @@ class AuthRestServlet(RestServlet):
 
     PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
-
-        # SSO configuration.
-        self._cas_enabled = hs.config.cas_enabled
-        if self._cas_enabled:
-            self._cas_handler = hs.get_cas_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-        self._saml_enabled = hs.config.saml2_enabled
-        if self._saml_enabled:
-            self._saml_handler = hs.get_saml_handler()
-        self._oidc_enabled = hs.config.oidc_enabled
-        if self._oidc_enabled:
-            self._oidc_handler = hs.get_oidc_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-
         self.recaptcha_template = hs.config.recaptcha_template
         self.terms_template = hs.config.terms_template
         self.success_template = hs.config.fallback_success_template
@@ -85,32 +73,7 @@ class AuthRestServlet(RestServlet):
         elif stagetype == LoginType.SSO:
             # Display a confirmation page which prompts the user to
             # re-authenticate with their SSO provider.
-            if self._cas_enabled:
-                # Generate a request to CAS that redirects back to an endpoint
-                # to verify the successful authentication.
-                sso_redirect_url = self._cas_handler.get_redirect_url(
-                    {"session": session},
-                )
-
-            elif self._saml_enabled:
-                # Some SAML identity providers (e.g. Google) require a
-                # RelayState parameter on requests. It is not necessary here, so
-                # pass in a dummy redirect URL (which will never get used).
-                client_redirect_url = b"unused"
-                sso_redirect_url = self._saml_handler.handle_redirect_request(
-                    client_redirect_url, session
-                )
-
-            elif self._oidc_enabled:
-                client_redirect_url = b""
-                sso_redirect_url = await self._oidc_handler.handle_redirect_request(
-                    request, client_redirect_url, session
-                )
-
-            else:
-                raise SynapseError(400, "Homeserver not configured for SSO.")
-
-            html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+            html = await self.auth_handler.start_sso_ui_auth(request, session)
 
         else:
             raise SynapseError(404, "Unknown auth stage type")
@@ -134,7 +97,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"response": response, "session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
+                LoginType.RECAPTCHA, authdict, request.getClientIP()
             )
 
             if success:
@@ -150,7 +113,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
+                LoginType.TERMS, authdict, request.getClientIP()
             )
 
             if success:
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index af117cb27c..314e01dfe4 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet):
         assert_params_in_dict(body, ["devices"])
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove device(s) from your account",
+            requester, request, body, "remove device(s) from your account",
         )
 
         await self.device_handler.delete_devices(
@@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet):
                 raise
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "remove a device from your account",
+            requester, request, body, "remove a device from your account",
         )
 
         await self.device_handler.delete_device(requester.user.to_string(), device_id)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 75215a3779..28b55f27ad 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from functools import wraps
 
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -25,6 +26,22 @@ from ._base import client_patterns
 logger = logging.getLogger(__name__)
 
 
+def _validate_group_id(f):
+    """Wrapper to validate the form of the group ID.
+
+    Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
+    """
+
+    @wraps(f)
+    def wrapper(self, request, group_id, *args, **kwargs):
+        if not GroupID.is_valid(group_id):
+            raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
+
+        return f(self, request, group_id, *args, **kwargs)
+
+    return wrapper
+
+
 class GroupServlet(RestServlet):
     """Get the group profile
     """
@@ -37,6 +54,7 @@ class GroupServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -47,6 +65,7 @@ class GroupServlet(RestServlet):
 
         return 200, group_description
 
+    @_validate_group_id
     async def on_POST(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -71,6 +90,7 @@ class GroupSummaryServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -102,6 +122,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, category_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -117,6 +138,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, category_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -142,6 +164,7 @@ class GroupCategoryServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id, category_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -152,6 +175,7 @@ class GroupCategoryServlet(RestServlet):
 
         return 200, category
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, category_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -163,6 +187,7 @@ class GroupCategoryServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, category_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -186,6 +211,7 @@ class GroupCategoriesServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -209,6 +235,7 @@ class GroupRoleServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id, role_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -219,6 +246,7 @@ class GroupRoleServlet(RestServlet):
 
         return 200, category
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, role_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -230,6 +258,7 @@ class GroupRoleServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, role_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -253,6 +282,7 @@ class GroupRolesServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -284,6 +314,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, role_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -299,6 +330,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
 
         return 200, resp
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, role_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -322,13 +354,11 @@ class GroupRoomServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
 
-        if not GroupID.is_valid(group_id):
-            raise SynapseError(400, "%s was not legal group ID" % (group_id,))
-
         result = await self.groups_handler.get_rooms_in_group(
             group_id, requester_user_id
         )
@@ -348,6 +378,7 @@ class GroupUsersServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         requester_user_id = requester.user.to_string()
@@ -371,6 +402,7 @@ class GroupInvitedUsersServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_GET(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -393,6 +425,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
         self.auth = hs.get_auth()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -449,6 +482,7 @@ class GroupAdminRoomsServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -460,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet):
 
         return 200, result
 
+    @_validate_group_id
     async def on_DELETE(self, request, group_id, room_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -486,6 +521,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, room_id, config_key):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -514,6 +550,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
         self.store = hs.get_datastore()
         self.is_mine_id = hs.is_mine_id
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -541,6 +578,7 @@ class GroupAdminUsersKickServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id, user_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -590,6 +628,7 @@ class GroupSelfLeaveServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -614,6 +653,7 @@ class GroupSelfJoinServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -638,6 +678,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
         self.clock = hs.get_clock()
         self.groups_handler = hs.get_groups_local_handler()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
@@ -662,6 +703,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
 
+    @_validate_group_id
     async def on_PUT(self, request, group_id):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index b91996c738..a6134ead8a 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet):
         body = parse_json_object_from_request(request)
 
         await self.auth_handler.validate_user_via_ui_auth(
-            requester,
-            request,
-            body,
-            self.hs.get_ip_from_request(request),
-            "add a device signing key to your account",
+            requester, request, body, "add a device signing key to your account",
         )
 
         result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 5374d2c1b6..f0675abd32 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
 from synapse.config.registration import RegistrationConfig
 from synapse.config.server import is_threepid_reserved
 from synapse.handlers.auth import AuthHandler
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http.server import finish_request, respond_with_html
 from synapse.http.servlet import (
     RestServlet,
@@ -125,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "email", email
         )
@@ -204,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
                 Codes.THREEPID_DENIED,
             )
 
+        self.identity_handler.ratelimit_request_token_requests(
+            request, "msisdn", msisdn
+        )
+
         existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
             "msisdn", msisdn
         )
@@ -353,7 +360,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
                 403, "Registration has been disabled", errcode=Codes.FORBIDDEN
             )
 
-        ip = self.hs.get_ip_from_request(request)
+        ip = request.getClientIP()
         with self.ratelimiter.ratelimit(ip) as wait_deferred:
             await wait_deferred
 
@@ -451,7 +458,7 @@ class RegisterRestServlet(RestServlet):
 
         # == Normal User Registration == (everyone else)
         if not self._registration_enabled:
-            raise SynapseError(403, "Registration has been disabled")
+            raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
 
         # For regular registration, convert the provided username to lowercase
         # before attempting to register it. This should mean that people who try
@@ -494,11 +501,11 @@ class RegisterRestServlet(RestServlet):
             # user here. We carry on and go through the auth checks though,
             # for paranoia.
             registered_user_id = await self.auth_handler.get_session_data(
-                session_id, "registered_user_id", None
+                session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
             )
             # Extract the previously-hashed password from the session.
             password_hash = await self.auth_handler.get_session_data(
-                session_id, "password_hash", None
+                session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
             )
 
         # Ensure that the username is valid.
@@ -513,11 +520,7 @@ class RegisterRestServlet(RestServlet):
         # not this will raise a user-interactive auth error.
         try:
             auth_result, params, session_id = await self.auth_handler.check_ui_auth(
-                self._registration_flows,
-                request,
-                body,
-                self.hs.get_ip_from_request(request),
-                "register a new account",
+                self._registration_flows, request, body, "register a new account",
             )
         except InteractiveAuthIncompleteError as e:
             # The user needs to provide more steps to complete auth.
@@ -532,7 +535,9 @@ class RegisterRestServlet(RestServlet):
             if not password_hash and password:
                 password_hash = await self.auth_handler.hash(password)
                 await self.auth_handler.set_session_data(
-                    e.session_id, "password_hash", password_hash
+                    e.session_id,
+                    UIAuthSessionDataConstants.PASSWORD_HASH,
+                    password_hash,
                 )
             raise
 
@@ -635,7 +640,9 @@ class RegisterRestServlet(RestServlet):
             # Remember that the user account has been registered (and the user
             # ID it was registered with, since it might not have been specified).
             await self.auth_handler.set_session_data(
-                session_id, "registered_user_id", registered_user_id
+                session_id,
+                UIAuthSessionDataConstants.REGISTERED_USER_ID,
+                registered_user_id,
             )
 
             registered = True
@@ -657,9 +664,13 @@ class RegisterRestServlet(RestServlet):
         user_id = await self.registration_handler.appservice_register(
             username, as_token
         )
-        return await self._create_registration_details(user_id, body)
+        return await self._create_registration_details(
+            user_id, body, is_appservice_ghost=True,
+        )
 
-    async def _create_registration_details(self, user_id, params):
+    async def _create_registration_details(
+        self, user_id, params, is_appservice_ghost=False
+    ):
         """Complete registration of newly-registered user
 
         Allocates device_id if one was not given; also creates access_token.
@@ -676,7 +687,11 @@ class RegisterRestServlet(RestServlet):
             device_id = params.get("device_id")
             initial_display_name = params.get("initial_device_display_name")
             device_id, access_token = await self.registration_handler.register_device(
-                user_id, device_id, initial_display_name, is_guest=False
+                user_id,
+                device_id,
+                initial_display_name,
+                is_guest=False,
+                is_appservice_ghost=is_appservice_ghost,
             )
 
             result.update({"access_token": access_token, "device_id": device_id})
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index bc4f43639a..a3dee14ed4 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -17,7 +17,7 @@ import logging
 from typing import Tuple
 
 from synapse.http import servlet
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
 from synapse.logging.opentracing import set_tag, trace
 from synapse.rest.client.transactions import HttpTransactionCache
 
@@ -54,6 +54,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         content = parse_json_object_from_request(request)
+        assert_params_in_dict(content, ("messages",))
 
         sender_user_id = requester.user.to_string()
 
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index bf3a79db44..a97cd66c52 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -58,8 +58,7 @@ class TagServlet(RestServlet):
     def __init__(self, hs):
         super().__init__()
         self.auth = hs.get_auth()
-        self.store = hs.get_datastore()
-        self.notifier = hs.get_notifier()
+        self.handler = hs.get_account_data_handler()
 
     async def on_PUT(self, request, user_id, room_id, tag):
         requester = await self.auth.get_user_by_req(request)
@@ -68,9 +67,7 @@ class TagServlet(RestServlet):
 
         body = parse_json_object_from_request(request)
 
-        max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.add_tag_to_room(user_id, room_id, tag, body)
 
         return 200, {}
 
@@ -79,9 +76,7 @@ class TagServlet(RestServlet):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add tags for other users.")
 
-        max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
-
-        self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+        await self.handler.remove_tag_from_room(user_id, room_id, tag)
 
         return 200, {}