summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8620.bugfix1
-rw-r--r--mypy.ini1
-rw-r--r--synapse/handlers/account_validity.py29
-rw-r--r--synapse/handlers/profile.py4
-rw-r--r--synapse/storage/databases/main/profile.py4
-rw-r--r--synapse/storage/databases/main/registration.py4
6 files changed, 31 insertions, 12 deletions
diff --git a/changelog.d/8620.bugfix b/changelog.d/8620.bugfix
new file mode 100644
index 0000000000..c1078a3fb5
--- /dev/null
+++ b/changelog.d/8620.bugfix
@@ -0,0 +1 @@
+Fix a bug where the account validity endpoint would silently fail if the user ID did not have an expiration time. It now returns a 400 error.
diff --git a/mypy.ini b/mypy.ini
index 59d9074c3b..1fbd8decf8 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -17,6 +17,7 @@ files =
   synapse/federation,
   synapse/handlers/_base.py,
   synapse/handlers/account_data.py,
+  synapse/handlers/account_validity.py,
   synapse/handlers/appservice.py,
   synapse/handlers/auth.py,
   synapse/handlers/cas_handler.py,
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index fd4f762f33..664d09da1c 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,19 +18,22 @@ import email.utils
 import logging
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
-from typing import List
+from typing import TYPE_CHECKING, List
 
-from synapse.api.errors import StoreError
+from synapse.api.errors import StoreError, SynapseError
 from synapse.logging.context import make_deferred_yieldable
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.types import UserID
 from synapse.util import stringutils
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class AccountValidityHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.config = hs.config
         self.store = self.hs.get_datastore()
@@ -67,7 +70,7 @@ class AccountValidityHandler:
                 self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
 
     @wrap_as_background_process("send_renewals")
-    async def _send_renewal_emails(self):
+    async def _send_renewal_emails(self) -> None:
         """Gets the list of users whose account is expiring in the amount of time
         configured in the ``renew_at`` parameter from the ``account_validity``
         configuration, and sends renewal emails to all of these users as long as they
@@ -81,11 +84,25 @@ class AccountValidityHandler:
                     user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
                 )
 
-    async def send_renewal_email_to_user(self, user_id: str):
+    async def send_renewal_email_to_user(self, user_id: str) -> None:
+        """
+        Send a renewal email for a specific user.
+
+        Args:
+            user_id: The user ID to send a renewal email for.
+
+        Raises:
+            SynapseError if the user is not set to renew.
+        """
         expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
+
+        # If this user isn't set to be expired, raise an error.
+        if expiration_ts is None:
+            raise SynapseError(400, "User has no expiration time: %s" % (user_id,))
+
         await self._send_renewal_email(user_id, expiration_ts)
 
-    async def _send_renewal_email(self, user_id: str, expiration_ts: int):
+    async def _send_renewal_email(self, user_id: str, expiration_ts: int) -> None:
         """Sends out a renewal email to every email address attached to the given user
         with a unique link allowing them to renew their account.
 
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 3875e53c08..14348faaf3 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -131,7 +131,7 @@ class ProfileHandler(BaseHandler):
             profile = await self.store.get_from_remote_profile_cache(user_id)
             return profile or {}
 
-    async def get_displayname(self, target_user: UserID) -> str:
+    async def get_displayname(self, target_user: UserID) -> Optional[str]:
         if self.hs.is_mine(target_user):
             try:
                 displayname = await self.store.get_profile_displayname(
@@ -218,7 +218,7 @@ class ProfileHandler(BaseHandler):
 
         await self._update_join_states(requester, target_user)
 
-    async def get_avatar_url(self, target_user: UserID) -> str:
+    async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
         if self.hs.is_mine(target_user):
             try:
                 avatar_url = await self.store.get_profile_avatar_url(
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index a6d1eb908a..0e25ca3d7a 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -39,7 +39,7 @@ class ProfileWorkerStore(SQLBaseStore):
             avatar_url=profile["avatar_url"], display_name=profile["displayname"]
         )
 
-    async def get_profile_displayname(self, user_localpart: str) -> str:
+    async def get_profile_displayname(self, user_localpart: str) -> Optional[str]:
         return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
@@ -47,7 +47,7 @@ class ProfileWorkerStore(SQLBaseStore):
             desc="get_profile_displayname",
         )
 
-    async def get_profile_avatar_url(self, user_localpart: str) -> str:
+    async def get_profile_avatar_url(self, user_localpart: str) -> Optional[str]:
         return await self.db_pool.simple_select_one_onecol(
             table="profiles",
             keyvalues={"user_id": user_localpart},
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index b0329e17ec..e7b17a7385 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -240,13 +240,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="get_renewal_token_for_user",
         )
 
-    async def get_users_expiring_soon(self) -> List[Dict[str, int]]:
+    async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
         """Selects users whose account will expire in the [now, now + renew_at] time
         window (see configuration for account_validity for information on what renew_at
         refers to).
 
         Returns:
-            A list of dictionaries mapping user ID to expiration time (in milliseconds).
+            A list of dictionaries, each with a user ID and expiration time (in milliseconds).
         """
 
         def select_users_txn(txn, now_ms, renew_at):