summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/api/auth.py17
-rw-r--r--synapse/config/account_validity.py102
-rw-r--r--synapse/handlers/account_validity.py128
-rw-r--r--synapse/handlers/register.py5
-rw-r--r--synapse/module_api/__init__.py219
-rw-r--r--synapse/module_api/errors.py6
-rw-r--r--synapse/push/pusherpool.py24
-rw-r--r--synapse/rest/admin/users.py24
-rw-r--r--synapse/rest/client/v2_alpha/account_validity.py7
9 files changed, 395 insertions, 137 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 8916e6fa2f..05699714ee 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -62,6 +62,7 @@ class Auth:
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
+        self._account_validity_handler = hs.get_account_validity_handler()
 
         self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
             10000, "token_cache"
@@ -69,9 +70,6 @@ class Auth:
 
         self._auth_blocking = AuthBlocking(self.hs)
 
-        self._account_validity_enabled = (
-            hs.config.account_validity.account_validity_enabled
-        )
         self._track_appservice_user_ips = hs.config.track_appservice_user_ips
         self._macaroon_secret_key = hs.config.macaroon_secret_key
         self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
@@ -187,12 +185,17 @@ class Auth:
             shadow_banned = user_info.shadow_banned
 
             # Deny the request if the user account has expired.
-            if self._account_validity_enabled and not allow_expired:
-                if await self.store.is_account_expired(
-                    user_info.user_id, self.clock.time_msec()
+            if not allow_expired:
+                if await self._account_validity_handler.is_user_expired(
+                    user_info.user_id
                 ):
+                    # Raise the error if either an account validity module has determined
+                    # the account has expired, or the legacy account validity
+                    # implementation is enabled and determined the account has expired
                     raise AuthError(
-                        403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
+                        403,
+                        "User account has expired",
+                        errcode=Codes.EXPIRED_ACCOUNT,
                     )
 
             device_id = user_info.device_id
diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py
index 957de7f3a6..6be4eafe55 100644
--- a/synapse/config/account_validity.py
+++ b/synapse/config/account_validity.py
@@ -18,6 +18,21 @@ class AccountValidityConfig(Config):
     section = "account_validity"
 
     def read_config(self, config, **kwargs):
+        """Parses the old account validity config. The config format looks like this:
+
+        account_validity:
+            enabled: true
+            period: 6w
+            renew_at: 1w
+            renew_email_subject: "Renew your %(app)s account"
+            template_dir: "res/templates"
+            account_renewed_html_path: "account_renewed.html"
+            invalid_token_html_path: "invalid_token.html"
+
+        We expect admins to use modules for this feature (which is why it doesn't appear
+        in the sample config file), but we want to keep support for it around for a bit
+        for backwards compatibility.
+        """
         account_validity_config = config.get("account_validity") or {}
         self.account_validity_enabled = account_validity_config.get("enabled", False)
         self.account_validity_renew_by_email_enabled = (
@@ -75,90 +90,3 @@ class AccountValidityConfig(Config):
             ],
             account_validity_template_dir,
         )
-
-    def generate_config_section(self, **kwargs):
-        return """\
-        ## Account Validity ##
-
-        # Optional account validity configuration. This allows for accounts to be denied
-        # any request after a given period.
-        #
-        # Once this feature is enabled, Synapse will look for registered users without an
-        # expiration date at startup and will add one to every account it found using the
-        # current settings at that time.
-        # This means that, if a validity period is set, and Synapse is restarted (it will
-        # then derive an expiration date from the current validity period), and some time
-        # after that the validity period changes and Synapse is restarted, the users'
-        # expiration dates won't be updated unless their account is manually renewed. This
-        # date will be randomly selected within a range [now + period - d ; now + period],
-        # where d is equal to 10% of the validity period.
-        #
-        account_validity:
-          # The account validity feature is disabled by default. Uncomment the
-          # following line to enable it.
-          #
-          #enabled: true
-
-          # The period after which an account is valid after its registration. When
-          # renewing the account, its validity period will be extended by this amount
-          # of time. This parameter is required when using the account validity
-          # feature.
-          #
-          #period: 6w
-
-          # The amount of time before an account's expiry date at which Synapse will
-          # send an email to the account's email address with a renewal link. By
-          # default, no such emails are sent.
-          #
-          # If you enable this setting, you will also need to fill out the 'email' and
-          # 'public_baseurl' configuration sections.
-          #
-          #renew_at: 1w
-
-          # The subject of the email sent out with the renewal link. '%(app)s' can be
-          # used as a placeholder for the 'app_name' parameter from the 'email'
-          # section.
-          #
-          # Note that the placeholder must be written '%(app)s', including the
-          # trailing 's'.
-          #
-          # If this is not set, a default value is used.
-          #
-          #renew_email_subject: "Renew your %(app)s account"
-
-          # Directory in which Synapse will try to find templates for the HTML files to
-          # serve to the user when trying to renew an account. If not set, default
-          # templates from within the Synapse package will be used.
-          #
-          # The currently available templates are:
-          #
-          # * account_renewed.html: Displayed to the user after they have successfully
-          #       renewed their account.
-          #
-          # * account_previously_renewed.html: Displayed to the user if they attempt to
-          #       renew their account with a token that is valid, but that has already
-          #       been used. In this case the account is not renewed again.
-          #
-          # * invalid_token.html: Displayed to the user when they try to renew an account
-          #       with an unknown or invalid renewal token.
-          #
-          # See https://github.com/matrix-org/synapse/tree/master/synapse/res/templates for
-          # default template contents.
-          #
-          # The file name of some of these templates can be configured below for legacy
-          # reasons.
-          #
-          #template_dir: "res/templates"
-
-          # A custom file name for the 'account_renewed.html' template.
-          #
-          # If not set, the file is assumed to be named "account_renewed.html".
-          #
-          #account_renewed_html_path: "account_renewed.html"
-
-          # A custom file name for the 'invalid_token.html' template.
-          #
-          # If not set, the file is assumed to be named "invalid_token.html".
-          #
-          #invalid_token_html_path: "invalid_token.html"
-        """
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index d752cf34f0..078accd634 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -15,9 +15,11 @@
 import email.mime.multipart
 import email.utils
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
 
-from synapse.api.errors import StoreError, SynapseError
+from twisted.web.http import Request
+
+from synapse.api.errors import AuthError, StoreError, SynapseError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.types import UserID
 from synapse.util import stringutils
@@ -27,6 +29,15 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# Types for callbacks to be registered via the module api
+IS_USER_EXPIRED_CALLBACK = Callable[[str], Awaitable[Optional[bool]]]
+ON_USER_REGISTRATION_CALLBACK = Callable[[str], Awaitable]
+# Temporary hooks to allow for a transition from `/_matrix/client` endpoints
+# to `/_synapse/client/account_validity`. See `register_account_validity_callbacks`.
+ON_LEGACY_SEND_MAIL_CALLBACK = Callable[[str], Awaitable]
+ON_LEGACY_RENEW_CALLBACK = Callable[[str], Awaitable[Tuple[bool, bool, int]]]
+ON_LEGACY_ADMIN_REQUEST = Callable[[Request], Awaitable]
+
 
 class AccountValidityHandler:
     def __init__(self, hs: "HomeServer"):
@@ -70,6 +81,99 @@ class AccountValidityHandler:
             if hs.config.run_background_tasks:
                 self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
 
+        self._is_user_expired_callbacks: List[IS_USER_EXPIRED_CALLBACK] = []
+        self._on_user_registration_callbacks: List[ON_USER_REGISTRATION_CALLBACK] = []
+        self._on_legacy_send_mail_callback: Optional[
+            ON_LEGACY_SEND_MAIL_CALLBACK
+        ] = None
+        self._on_legacy_renew_callback: Optional[ON_LEGACY_RENEW_CALLBACK] = None
+
+        # The legacy admin requests callback isn't a protected attribute because we need
+        # to access it from the admin servlet, which is outside of this handler.
+        self.on_legacy_admin_request_callback: Optional[ON_LEGACY_ADMIN_REQUEST] = None
+
+    def register_account_validity_callbacks(
+        self,
+        is_user_expired: Optional[IS_USER_EXPIRED_CALLBACK] = None,
+        on_user_registration: Optional[ON_USER_REGISTRATION_CALLBACK] = None,
+        on_legacy_send_mail: Optional[ON_LEGACY_SEND_MAIL_CALLBACK] = None,
+        on_legacy_renew: Optional[ON_LEGACY_RENEW_CALLBACK] = None,
+        on_legacy_admin_request: Optional[ON_LEGACY_ADMIN_REQUEST] = None,
+    ):
+        """Register callbacks from module for each hook."""
+        if is_user_expired is not None:
+            self._is_user_expired_callbacks.append(is_user_expired)
+
+        if on_user_registration is not None:
+            self._on_user_registration_callbacks.append(on_user_registration)
+
+        # The builtin account validity feature exposes 3 endpoints (send_mail, renew, and
+        # an admin one). As part of moving the feature into a module, we need to change
+        # the path from /_matrix/client/unstable/account_validity/... to
+        # /_synapse/client/account_validity, because:
+        #
+        #   * the feature isn't part of the Matrix spec thus shouldn't live under /_matrix
+        #   * the way we register servlets means that modules can't register resources
+        #     under /_matrix/client
+        #
+        # We need to allow for a transition period between the old and new endpoints
+        # in order to allow for clients to update (and for emails to be processed).
+        #
+        # Once the email-account-validity module is loaded, it will take control of account
+        # validity by moving the rows from our `account_validity` table into its own table.
+        #
+        # Therefore, we need to allow modules (in practice just the one implementing the
+        # email-based account validity) to temporarily hook into the legacy endpoints so we
+        # can route the traffic coming into the old endpoints into the module, which is
+        # why we have the following three temporary hooks.
+        if on_legacy_send_mail is not None:
+            if self._on_legacy_send_mail_callback is not None:
+                raise RuntimeError("Tried to register on_legacy_send_mail twice")
+
+            self._on_legacy_send_mail_callback = on_legacy_send_mail
+
+        if on_legacy_renew is not None:
+            if self._on_legacy_renew_callback is not None:
+                raise RuntimeError("Tried to register on_legacy_renew twice")
+
+            self._on_legacy_renew_callback = on_legacy_renew
+
+        if on_legacy_admin_request is not None:
+            if self.on_legacy_admin_request_callback is not None:
+                raise RuntimeError("Tried to register on_legacy_admin_request twice")
+
+            self.on_legacy_admin_request_callback = on_legacy_admin_request
+
+    async def is_user_expired(self, user_id: str) -> bool:
+        """Checks if a user has expired against third-party modules.
+
+        Args:
+            user_id: The user to check the expiry of.
+
+        Returns:
+            Whether the user has expired.
+        """
+        for callback in self._is_user_expired_callbacks:
+            expired = await callback(user_id)
+            if expired is not None:
+                return expired
+
+        if self._account_validity_enabled:
+            # If no module could determine whether the user has expired and the legacy
+            # configuration is enabled, fall back to it.
+            return await self.store.is_account_expired(user_id, self.clock.time_msec())
+
+        return False
+
+    async def on_user_registration(self, user_id: str):
+        """Tell third-party modules about a user's registration.
+
+        Args:
+            user_id: The ID of the newly registered user.
+        """
+        for callback in self._on_user_registration_callbacks:
+            await callback(user_id)
+
     @wrap_as_background_process("send_renewals")
     async def _send_renewal_emails(self) -> None:
         """Gets the list of users whose account is expiring in the amount of time
@@ -95,6 +199,17 @@ class AccountValidityHandler:
         Raises:
             SynapseError if the user is not set to renew.
         """
+        # If a module supports sending a renewal email from here, do that, otherwise do
+        # the legacy dance.
+        if self._on_legacy_send_mail_callback is not None:
+            await self._on_legacy_send_mail_callback(user_id)
+            return
+
+        if not self._account_validity_renew_by_email_enabled:
+            raise AuthError(
+                403, "Account renewal via email is disabled on this server."
+            )
+
         expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
 
         # If this user isn't set to be expired, raise an error.
@@ -209,6 +324,10 @@ class AccountValidityHandler:
         token is considered stale. A token is stale if the 'token_used_ts_ms' db column
         is non-null.
 
+        This method exists to support handling the legacy account validity /renew
+        endpoint. If a module implements the on_legacy_renew callback, then this process
+        is delegated to the module instead.
+
         Args:
             renewal_token: Token sent with the renewal request.
         Returns:
@@ -218,6 +337,11 @@ class AccountValidityHandler:
               * An int representing the user's expiry timestamp as milliseconds since the
                 epoch, or 0 if the token was invalid.
         """
+        # If a module supports triggering a renew from here, do that, otherwise do the
+        # legacy dance.
+        if self._on_legacy_renew_callback is not None:
+            return await self._on_legacy_renew_callback(renewal_token)
+
         try:
             (
                 user_id,
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 26ef016179..056fe5e89f 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -77,6 +77,7 @@ class RegistrationHandler(BaseHandler):
         self.identity_handler = self.hs.get_identity_handler()
         self.ratelimiter = hs.get_registration_ratelimiter()
         self.macaroon_gen = hs.get_macaroon_generator()
+        self._account_validity_handler = hs.get_account_validity_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
         self._server_name = hs.hostname
 
@@ -700,6 +701,10 @@ class RegistrationHandler(BaseHandler):
                 shadow_banned=shadow_banned,
             )
 
+            # Only call the account validity module(s) on the main process, to avoid
+            # repeating e.g. database writes on all of the workers.
+            await self._account_validity_handler.on_user_registration(user_id)
+
     async def register_device(
         self,
         user_id: str,
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 308f045700..f3c78089b7 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -12,18 +12,42 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import email.utils
 import logging
-from typing import TYPE_CHECKING, Any, Generator, Iterable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+)
+
+import jinja2
 
 from twisted.internet import defer
 from twisted.web.resource import IResource
 
 from synapse.events import EventBase
 from synapse.http.client import SimpleHttpClient
+from synapse.http.server import (
+    DirectServeHtmlResource,
+    DirectServeJsonResource,
+    respond_with_html,
+)
+from synapse.http.servlet import parse_json_object_from_request
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.databases.main.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.util import Clock
+from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -33,7 +57,20 @@ This package defines the 'stable' API which can be used by extension modules whi
 are loaded into Synapse.
 """
 
-__all__ = ["errors", "make_deferred_yieldable", "run_in_background", "ModuleApi"]
+__all__ = [
+    "errors",
+    "make_deferred_yieldable",
+    "parse_json_object_from_request",
+    "respond_with_html",
+    "run_in_background",
+    "cached",
+    "UserID",
+    "DatabasePool",
+    "LoggingTransaction",
+    "DirectServeHtmlResource",
+    "DirectServeJsonResource",
+    "ModuleApi",
+]
 
 logger = logging.getLogger(__name__)
 
@@ -52,12 +89,27 @@ class ModuleApi:
         self._server_name = hs.hostname
         self._presence_stream = hs.get_event_sources().sources["presence"]
         self._state = hs.get_state_handler()
+        self._clock = hs.get_clock()  # type: Clock
+        self._send_email_handler = hs.get_send_email_handler()
+
+        try:
+            app_name = self._hs.config.email_app_name
+
+            self._from_string = self._hs.config.email_notif_from % {"app": app_name}
+        except (KeyError, TypeError):
+            # If substitution failed (which can happen if the string contains
+            # placeholders other than just "app", or if the type of the placeholder is
+            # not a string), fall back to the bare strings.
+            self._from_string = self._hs.config.email_notif_from
+
+        self._raw_from = email.utils.parseaddr(self._from_string)[1]
 
         # We expose these as properties below in order to attach a helpful docstring.
         self._http_client: SimpleHttpClient = hs.get_simple_http_client()
         self._public_room_list_manager = PublicRoomListManager(hs)
 
         self._spam_checker = hs.get_spam_checker()
+        self._account_validity_handler = hs.get_account_validity_handler()
 
     #################################################################################
     # The following methods should only be called during the module's initialisation.
@@ -67,6 +119,11 @@ class ModuleApi:
         """Registers callbacks for spam checking capabilities."""
         return self._spam_checker.register_callbacks
 
+    @property
+    def register_account_validity_callbacks(self):
+        """Registers callbacks for account validity capabilities."""
+        return self._account_validity_handler.register_account_validity_callbacks
+
     def register_web_resource(self, path: str, resource: IResource):
         """Registers a web resource to be served at the given path.
 
@@ -101,22 +158,56 @@ class ModuleApi:
         """
         return self._public_room_list_manager
 
-    def get_user_by_req(self, req, allow_guest=False):
+    @property
+    def public_baseurl(self) -> str:
+        """The configured public base URL for this homeserver."""
+        return self._hs.config.public_baseurl
+
+    @property
+    def email_app_name(self) -> str:
+        """The application name configured in the homeserver's configuration."""
+        return self._hs.config.email.email_app_name
+
+    async def get_user_by_req(
+        self,
+        req: SynapseRequest,
+        allow_guest: bool = False,
+        allow_expired: bool = False,
+    ) -> Requester:
         """Check the access_token provided for a request
 
         Args:
-            req (twisted.web.server.Request): Incoming HTTP request
-            allow_guest (bool): True if guest users should be allowed. If this
+            req: Incoming HTTP request
+            allow_guest: True if guest users should be allowed. If this
                 is False, and the access token is for a guest user, an
                 AuthError will be thrown
+            allow_expired: True if expired users should be allowed. If this
+                is False, and the access token is for an expired user, an
+                AuthError will be thrown
+
         Returns:
-            twisted.internet.defer.Deferred[synapse.types.Requester]:
-                the requester for this request
+            The requester for this request
+
         Raises:
-            synapse.api.errors.AuthError: if no user by that token exists,
+            InvalidClientCredentialsError: if no user by that token exists,
                 or the token is invalid.
         """
-        return self._auth.get_user_by_req(req, allow_guest)
+        return await self._auth.get_user_by_req(
+            req,
+            allow_guest,
+            allow_expired=allow_expired,
+        )
+
+    async def is_user_admin(self, user_id: str) -> bool:
+        """Checks if a user is a server admin.
+
+        Args:
+            user_id: The Matrix ID of the user to check.
+
+        Returns:
+            True if the user is a server admin, False otherwise.
+        """
+        return await self._store.is_server_admin(UserID.from_string(user_id))
 
     def get_qualified_user_id(self, username):
         """Qualify a user id, if necessary
@@ -134,6 +225,32 @@ class ModuleApi:
             return username
         return UserID(username, self._hs.hostname).to_string()
 
+    async def get_profile_for_user(self, localpart: str) -> ProfileInfo:
+        """Look up the profile info for the user with the given localpart.
+
+        Args:
+            localpart: The localpart to look up profile information for.
+
+        Returns:
+            The profile information (i.e. display name and avatar URL).
+        """
+        return await self._store.get_profileinfo(localpart)
+
+    async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]:
+        """Look up the threepids (email addresses and phone numbers) associated with the
+        given Matrix user ID.
+
+        Args:
+            user_id: The Matrix user ID to look up threepids for.
+
+        Returns:
+            A list of threepids, each threepid being represented by a dictionary
+            containing a "medium" key which value is "email" for email addresses and
+            "msisdn" for phone numbers, and an "address" key which value is the
+            threepid's address.
+        """
+        return await self._store.user_get_threepids(user_id)
+
     def check_user_exists(self, user_id):
         """Check if user exists.
 
@@ -464,6 +581,88 @@ class ModuleApi:
                 presence_events, destination
             )
 
+    def looping_background_call(
+        self,
+        f: Callable,
+        msec: float,
+        *args,
+        desc: Optional[str] = None,
+        **kwargs,
+    ):
+        """Wraps a function as a background process and calls it repeatedly.
+
+        Waits `msec` initially before calling `f` for the first time.
+
+        Args:
+            f: The function to call repeatedly. f can be either synchronous or
+                asynchronous, and must follow Synapse's logcontext rules.
+                More info about logcontexts is available at
+                https://matrix-org.github.io/synapse/latest/log_contexts.html
+            msec: How long to wait between calls in milliseconds.
+            *args: Positional arguments to pass to function.
+            desc: The background task's description. Default to the function's name.
+            **kwargs: Key arguments to pass to function.
+        """
+        if desc is None:
+            desc = f.__name__
+
+        if self._hs.config.run_background_tasks:
+            self._clock.looping_call(
+                run_as_background_process,
+                msec,
+                desc,
+                f,
+                *args,
+                **kwargs,
+            )
+        else:
+            logger.warning(
+                "Not running looping call %s as the configuration forbids it",
+                f,
+            )
+
+    async def send_mail(
+        self,
+        recipient: str,
+        subject: str,
+        html: str,
+        text: str,
+    ):
+        """Send an email on behalf of the homeserver.
+
+        Args:
+            recipient: The email address for the recipient.
+            subject: The email's subject.
+            html: The email's HTML content.
+            text: The email's text content.
+        """
+        await self._send_email_handler.send_email(
+            email_address=recipient,
+            subject=subject,
+            app_name=self.email_app_name,
+            html=html,
+            text=text,
+        )
+
+    def read_templates(
+        self,
+        filenames: List[str],
+        custom_template_directory: Optional[str] = None,
+    ) -> List[jinja2.Template]:
+        """Read and load the content of the template files at the given location.
+        By default, Synapse will look for these templates in its configured template
+        directory, but another directory to search in can be provided.
+
+        Args:
+            filenames: The name of the template files to look for.
+            custom_template_directory: An additional directory to look for the files in.
+
+        Returns:
+            A list containing the loaded templates, with the orders matching the one of
+            the filenames parameter.
+        """
+        return self._hs.config.read_templates(filenames, custom_template_directory)
+
 
 class PublicRoomListManager:
     """Contains methods for adding to, removing from and querying whether a room
diff --git a/synapse/module_api/errors.py b/synapse/module_api/errors.py
index 02bbb0be39..98ea911a81 100644
--- a/synapse/module_api/errors.py
+++ b/synapse/module_api/errors.py
@@ -14,5 +14,9 @@
 
 """Exception types which are exposed as part of the stable module API"""
 
-from synapse.api.errors import RedirectException, SynapseError  # noqa: F401
+from synapse.api.errors import (  # noqa: F401
+    InvalidClientCredentialsError,
+    RedirectException,
+    SynapseError,
+)
 from synapse.config._base import ConfigError  # noqa: F401
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 2519ad76db..85621f33ef 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -62,10 +62,6 @@ class PusherPool:
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
 
-        self._account_validity_enabled = (
-            hs.config.account_validity.account_validity_enabled
-        )
-
         # We shard the handling of push notifications by user ID.
         self._pusher_shard_config = hs.config.push.pusher_shard_config
         self._instance_name = hs.get_instance_name()
@@ -89,6 +85,8 @@ class PusherPool:
         # map from user id to app_id:pushkey to pusher
         self.pushers: Dict[str, Dict[str, Pusher]] = {}
 
+        self._account_validity_handler = hs.get_account_validity_handler()
+
     def start(self) -> None:
         """Starts the pushers off in a background process."""
         if not self._should_start_pushers:
@@ -238,12 +236,9 @@ class PusherPool:
 
             for u in users_affected:
                 # Don't push if the user account has expired
-                if self._account_validity_enabled:
-                    expired = await self.store.is_account_expired(
-                        u, self.clock.time_msec()
-                    )
-                    if expired:
-                        continue
+                expired = await self._account_validity_handler.is_user_expired(u)
+                if expired:
+                    continue
 
                 if u in self.pushers:
                     for p in self.pushers[u].values():
@@ -268,12 +263,9 @@ class PusherPool:
 
             for u in users_affected:
                 # Don't push if the user account has expired
-                if self._account_validity_enabled:
-                    expired = await self.store.is_account_expired(
-                        u, self.clock.time_msec()
-                    )
-                    if expired:
-                        continue
+                expired = await self._account_validity_handler.is_user_expired(u)
+                if expired:
+                    continue
 
                 if u in self.pushers:
                     for p in self.pushers[u].values():
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 7d75564758..06e6ccee42 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -560,16 +560,24 @@ class AccountValidityRenewServlet(RestServlet):
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
-        body = parse_json_object_from_request(request)
+        if self.account_activity_handler.on_legacy_admin_request_callback:
+            expiration_ts = await (
+                self.account_activity_handler.on_legacy_admin_request_callback(request)
+            )
+        else:
+            body = parse_json_object_from_request(request)
 
-        if "user_id" not in body:
-            raise SynapseError(400, "Missing property 'user_id' in the request body")
+            if "user_id" not in body:
+                raise SynapseError(
+                    400,
+                    "Missing property 'user_id' in the request body",
+                )
 
-        expiration_ts = await self.account_activity_handler.renew_account_for_user(
-            body["user_id"],
-            body.get("expiration_ts"),
-            not body.get("enable_renewal_emails", True),
-        )
+            expiration_ts = await self.account_activity_handler.renew_account_for_user(
+                body["user_id"],
+                body.get("expiration_ts"),
+                not body.get("enable_renewal_emails", True),
+            )
 
         res = {"expiration_ts": expiration_ts}
         return 200, res
diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index 2d1ad3d3fb..3ebe401861 100644
--- a/synapse/rest/client/v2_alpha/account_validity.py
+++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -14,7 +14,7 @@
 
 import logging
 
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import SynapseError
 from synapse.http.server import respond_with_html
 from synapse.http.servlet import RestServlet
 
@@ -92,11 +92,6 @@ class AccountValiditySendMailServlet(RestServlet):
         )
 
     async def on_POST(self, request):
-        if not self.account_validity_renew_by_email_enabled:
-            raise AuthError(
-                403, "Account renewal via email is disabled on this server."
-            )
-
         requester = await self.auth.get_user_by_req(request, allow_expired=True)
         user_id = requester.user.to_string()
         await self.account_activity_handler.send_renewal_email_to_user(user_id)