summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/9876.misc1
-rw-r--r--synapse/api/auth.py78
-rw-r--r--synapse/api/auth_blocking.py9
-rw-r--r--synapse/event_auth.py4
4 files changed, 48 insertions, 44 deletions
diff --git a/changelog.d/9876.misc b/changelog.d/9876.misc
new file mode 100644
index 0000000000..28390e32e6
--- /dev/null
+++ b/changelog.d/9876.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.api.auth` and `synapse.api.auth_blocking` modules.
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 2d845d0d5c..efc926d094 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 import pymacaroons
 from netaddr import IPAddress
 
 from twisted.web.server import Request
 
-import synapse.types
 from synapse import event_auth
 from synapse.api.auth_blocking import AuthBlocking
 from synapse.api.constants import EventTypes, HistoryVisibility, Membership
@@ -36,11 +35,14 @@ from synapse.http import get_request_user_agent
 from synapse.http.site import SynapseRequest
 from synapse.logging import opentracing as opentracing
 from synapse.storage.databases.main.registration import TokenLookupResult
-from synapse.types import StateMap, UserID
+from synapse.types import Requester, StateMap, UserID, create_requester
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -68,7 +70,7 @@ class Auth:
     The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
@@ -88,13 +90,13 @@ class Auth:
 
     async def check_from_context(
         self, room_version: str, event, context, do_sig_check=True
-    ):
+    ) -> None:
         prev_state_ids = await context.get_prev_state_ids()
         auth_events_ids = self.compute_auth_events(
             event, prev_state_ids, for_verification=True
         )
-        auth_events = await self.store.get_events(auth_events_ids)
-        auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+        auth_events_by_id = await self.store.get_events(auth_events_ids)
+        auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
 
         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
         event_auth.check(
@@ -151,17 +153,11 @@ class Auth:
 
         raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
 
-    async def check_host_in_room(self, room_id, host):
+    async def check_host_in_room(self, room_id: str, host: str) -> bool:
         with Measure(self.clock, "check_host_in_room"):
-            latest_event_ids = await self.store.is_host_joined(room_id, host)
-            return latest_event_ids
-
-    def can_federate(self, event, auth_events):
-        creation_event = auth_events.get((EventTypes.Create, ""))
+            return await self.store.is_host_joined(room_id, host)
 
-        return creation_event.content.get("m.federate", True) is True
-
-    def get_public_keys(self, invite_event):
+    def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
         return event_auth.get_public_keys(invite_event)
 
     async def get_user_by_req(
@@ -170,7 +166,7 @@ class Auth:
         allow_guest: bool = False,
         rights: str = "access",
         allow_expired: bool = False,
-    ) -> synapse.types.Requester:
+    ) -> Requester:
         """Get a registered user's ID.
 
         Args:
@@ -196,7 +192,7 @@ class Auth:
             access_token = self.get_access_token_from_request(request)
 
             user_id, app_service = await self._get_appservice_user_id(request)
-            if user_id:
+            if user_id and app_service:
                 if ip_addr and self._track_appservice_user_ips:
                     await self.store.insert_client_ip(
                         user_id=user_id,
@@ -206,9 +202,7 @@ class Auth:
                         device_id="dummy-device",  # stubbed
                     )
 
-                requester = synapse.types.create_requester(
-                    user_id, app_service=app_service
-                )
+                requester = create_requester(user_id, app_service=app_service)
 
                 request.requester = user_id
                 opentracing.set_tag("authenticated_entity", user_id)
@@ -251,7 +245,7 @@ class Auth:
                     errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                 )
 
-            requester = synapse.types.create_requester(
+            requester = create_requester(
                 user_info.user_id,
                 token_id,
                 is_guest,
@@ -271,7 +265,9 @@ class Auth:
         except KeyError:
             raise MissingClientTokenError()
 
-    async def _get_appservice_user_id(self, request):
+    async def _get_appservice_user_id(
+        self, request: Request
+    ) -> Tuple[Optional[str], Optional[ApplicationService]]:
         app_service = self.store.get_app_service_by_token(
             self.get_access_token_from_request(request)
         )
@@ -283,6 +279,9 @@ class Auth:
             if ip_address not in app_service.ip_range_whitelist:
                 return None, None
 
+        # This will always be set by the time Twisted calls us.
+        assert request.args is not None
+
         if b"user_id" not in request.args:
             return app_service.sender, app_service
 
@@ -387,7 +386,9 @@ class Auth:
             logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
             raise InvalidClientTokenError("Invalid macaroon passed.")
 
-    def _parse_and_validate_macaroon(self, token, rights="access"):
+    def _parse_and_validate_macaroon(
+        self, token: str, rights: str = "access"
+    ) -> Tuple[str, bool]:
         """Takes a macaroon and tries to parse and validate it. This is cached
         if and only if rights == access and there isn't an expiry.
 
@@ -432,15 +433,16 @@ class Auth:
 
         return user_id, guest
 
-    def validate_macaroon(self, macaroon, type_string, user_id):
+    def validate_macaroon(
+        self, macaroon: pymacaroons.Macaroon, type_string: str, user_id: str
+    ) -> None:
         """
         validate that a Macaroon is understood by and was signed by this server.
 
         Args:
-            macaroon(pymacaroons.Macaroon): The macaroon to validate
-            type_string(str): The kind of token required (e.g. "access",
-                              "delete_pusher")
-            user_id (str): The user_id required
+            macaroon: The macaroon to validate
+            type_string: The kind of token required (e.g. "access", "delete_pusher")
+            user_id: The user_id required
         """
         v = pymacaroons.Verifier()
 
@@ -465,9 +467,7 @@ class Auth:
         if not service:
             logger.warning("Unrecognised appservice access token.")
             raise InvalidClientTokenError()
-        request.requester = synapse.types.create_requester(
-            service.sender, app_service=service
-        )
+        request.requester = create_requester(service.sender, app_service=service)
         return service
 
     async def is_server_admin(self, user: UserID) -> bool:
@@ -519,7 +519,7 @@ class Auth:
 
         return auth_ids
 
-    async def check_can_change_room_list(self, room_id: str, user: UserID):
+    async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
         """Determine whether the user is allowed to edit the room's entry in the
         published room list.
 
@@ -554,11 +554,11 @@ class Auth:
         return user_level >= send_level
 
     @staticmethod
-    def has_access_token(request: Request):
+    def has_access_token(request: Request) -> bool:
         """Checks if the request has an access_token.
 
         Returns:
-            bool: False if no access_token was given, True otherwise.
+            False if no access_token was given, True otherwise.
         """
         # This will always be set by the time Twisted calls us.
         assert request.args is not None
@@ -568,13 +568,13 @@ class Auth:
         return bool(query_params) or bool(auth_headers)
 
     @staticmethod
-    def get_access_token_from_request(request: Request):
+    def get_access_token_from_request(request: Request) -> str:
         """Extracts the access_token from the request.
 
         Args:
             request: The http request.
         Returns:
-            unicode: The access_token
+            The access_token
         Raises:
             MissingClientTokenError: If there isn't a single access_token in the
                 request
@@ -649,5 +649,5 @@ class Auth:
                 % (user_id, room_id),
             )
 
-    def check_auth_blocking(self, *args, **kwargs):
-        return self._auth_blocking.check_auth_blocking(*args, **kwargs)
+    async def check_auth_blocking(self, *args, **kwargs) -> None:
+        await self._auth_blocking.check_auth_blocking(*args, **kwargs)
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index a8df60cb89..e6bced93d5 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -13,18 +13,21 @@
 # limitations under the License.
 
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.constants import LimitBlockingTypes, UserTypes
 from synapse.api.errors import Codes, ResourceLimitError
 from synapse.config.server import is_threepid_reserved
 from synapse.types import Requester
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class AuthBlocking:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
         self._server_notices_mxid = hs.config.server_notices_mxid
@@ -43,7 +46,7 @@ class AuthBlocking:
         threepid: Optional[dict] = None,
         user_type: Optional[str] = None,
         requester: Optional[Requester] = None,
-    ):
+    ) -> None:
         """Checks if the user should be rejected for some external reason,
         such as monthly active user limiting or global disable flag
 
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c831d9f73c..afc2bc8267 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
 
 from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
@@ -688,7 +688,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase
     return False
 
 
-def get_public_keys(invite_event):
+def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
     public_keys = []
     if "public_key" in invite_event.content:
         o = {"public_key": invite_event.content["public_key"]}