summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-03-10 18:15:56 +0000
committerGitHub <noreply@github.com>2021-03-10 18:15:56 +0000
commita7a379006651ea219c2618c0a47e8f4053a9e769 (patch)
tree69e7b5a439c4304e52dc1170dd74ee80f3053843 /synapse
parentFix the auth provider on the logins metric (#9573) (diff)
downloadsynapse-a7a379006651ea219c2618c0a47e8f4053a9e769.tar.xz
Convert Requester to attrs (#9586)
... because namedtuples suck

Fix up a couple of other annotations to keep mypy happy.
Diffstat (limited to '')
-rw-r--r--synapse/handlers/auth.py5
-rw-r--r--synapse/rest/media/v1/media_repository.py3
-rw-r--r--synapse/storage/databases/main/registration.py6
-rw-r--r--synapse/types.py57
4 files changed, 36 insertions, 35 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index bec0c615d4..fb5f8118f0 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -337,7 +337,8 @@ class AuthHandler(BaseHandler):
                 user is too high to proceed
 
         """
-
+        if not requester.access_token_id:
+            raise ValueError("Cannot validate a user without an access token")
         if self._ui_auth_session_timeout:
             last_validated = await self.store.get_access_token_last_validated(
                 requester.access_token_id
@@ -1213,7 +1214,7 @@ class AuthHandler(BaseHandler):
     async def delete_access_tokens_for_user(
         self,
         user_id: str,
-        except_token_id: Optional[str] = None,
+        except_token_id: Optional[int] = None,
         device_id: Optional[str] = None,
     ):
         """Invalidate access tokens belonging to a user
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 0641924f18..8b4841ed5d 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -35,6 +35,7 @@ from synapse.api.errors import (
 from synapse.config._base import ConfigError
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.types import UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.util.stringutils import random_string
@@ -145,7 +146,7 @@ class MediaRepository:
         upload_name: Optional[str],
         content: IO,
         content_length: int,
-        auth_user: str,
+        auth_user: UserID,
     ) -> str:
         """Store uploaded content for a local user and return the mxc URL
 
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 61a7556e56..eba66ff352 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import attr
 
@@ -1510,7 +1510,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
     async def user_delete_access_tokens(
         self,
         user_id: str,
-        except_token_id: Optional[str] = None,
+        except_token_id: Optional[int] = None,
         device_id: Optional[str] = None,
     ) -> List[Tuple[str, int, Optional[str]]]:
         """
@@ -1533,7 +1533,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
             items = keyvalues.items()
             where_clause = " AND ".join(k + " = ?" for k, _ in items)
-            values = [v for _, v in items]
+            values = [v for _, v in items]  # type: List[Union[str, int]]
             if except_token_id:
                 where_clause += " AND id != ?"
                 values.append(except_token_id)
diff --git a/synapse/types.py b/synapse/types.py
index 0216d213c7..b08ce90140 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -83,33 +83,32 @@ class ISynapseReactor(
     """The interfaces necessary for Synapse to function."""
 
 
-class Requester(
-    namedtuple(
-        "Requester",
-        [
-            "user",
-            "access_token_id",
-            "is_guest",
-            "shadow_banned",
-            "device_id",
-            "app_service",
-            "authenticated_entity",
-        ],
-    )
-):
+@attr.s(frozen=True, slots=True)
+class Requester:
     """
     Represents the user making a request
 
     Attributes:
-        user (UserID):  id of the user making the request
-        access_token_id (int|None):  *ID* of the access token used for this
+        user:  id of the user making the request
+        access_token_id:  *ID* of the access token used for this
             request, or None if it came via the appservice API or similar
-        is_guest (bool):  True if the user making this request is a guest user
-        shadow_banned (bool):  True if the user making this request has been shadow-banned.
-        device_id (str|None):  device_id which was set at authentication time
-        app_service (ApplicationService|None):  the AS requesting on behalf of the user
+        is_guest:  True if the user making this request is a guest user
+        shadow_banned:  True if the user making this request has been shadow-banned.
+        device_id:  device_id which was set at authentication time
+        app_service:  the AS requesting on behalf of the user
+        authenticated_entity: The entity that authenticated when making the request.
+            This is different to the user_id when an admin user or the server is
+            "puppeting" the user.
     """
 
+    user = attr.ib(type="UserID")
+    access_token_id = attr.ib(type=Optional[int])
+    is_guest = attr.ib(type=bool)
+    shadow_banned = attr.ib(type=bool)
+    device_id = attr.ib(type=Optional[str])
+    app_service = attr.ib(type=Optional["ApplicationService"])
+    authenticated_entity = attr.ib(type=str)
+
     def serialize(self):
         """Converts self to a type that can be serialized as JSON, and then
         deserialized by `deserialize`
@@ -157,23 +156,23 @@ class Requester(
 def create_requester(
     user_id: Union[str, "UserID"],
     access_token_id: Optional[int] = None,
-    is_guest: Optional[bool] = False,
-    shadow_banned: Optional[bool] = False,
+    is_guest: bool = False,
+    shadow_banned: bool = False,
     device_id: Optional[str] = None,
     app_service: Optional["ApplicationService"] = None,
     authenticated_entity: Optional[str] = None,
-):
+) -> Requester:
     """
     Create a new ``Requester`` object
 
     Args:
-        user_id (str|UserID):  id of the user making the request
-        access_token_id (int|None):  *ID* of the access token used for this
+        user_id:  id of the user making the request
+        access_token_id:  *ID* of the access token used for this
             request, or None if it came via the appservice API or similar
-        is_guest (bool):  True if the user making this request is a guest user
-        shadow_banned (bool):  True if the user making this request is shadow-banned.
-        device_id (str|None):  device_id which was set at authentication time
-        app_service (ApplicationService|None):  the AS requesting on behalf of the user
+        is_guest:  True if the user making this request is a guest user
+        shadow_banned:  True if the user making this request is shadow-banned.
+        device_id:  device_id which was set at authentication time
+        app_service:  the AS requesting on behalf of the user
         authenticated_entity: The entity that authenticated when making the request.
             This is different to the user_id when an admin user or the server is
             "puppeting" the user.