diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 2d845d0d5c..efc926d094 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -12,14 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import pymacaroons
from netaddr import IPAddress
from twisted.web.server import Request
-import synapse.types
from synapse import event_auth
from synapse.api.auth_blocking import AuthBlocking
from synapse.api.constants import EventTypes, HistoryVisibility, Membership
@@ -36,11 +35,14 @@ from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
-from synapse.types import StateMap, UserID
+from synapse.types import Requester, StateMap, UserID, create_requester
from synapse.util.caches.lrucache import LruCache
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.metrics import Measure
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -68,7 +70,7 @@ class Auth:
The latter should be moved to synapse.handlers.event_auth.EventAuthHandler.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@@ -88,13 +90,13 @@ class Auth:
async def check_from_context(
self, room_version: str, event, context, do_sig_check=True
- ):
+ ) -> None:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
- auth_events = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+ auth_events_by_id = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_by_id.values()}
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
event_auth.check(
@@ -151,17 +153,11 @@ class Auth:
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
- async def check_host_in_room(self, room_id, host):
+ async def check_host_in_room(self, room_id: str, host: str) -> bool:
with Measure(self.clock, "check_host_in_room"):
- latest_event_ids = await self.store.is_host_joined(room_id, host)
- return latest_event_ids
-
- def can_federate(self, event, auth_events):
- creation_event = auth_events.get((EventTypes.Create, ""))
+ return await self.store.is_host_joined(room_id, host)
- return creation_event.content.get("m.federate", True) is True
-
- def get_public_keys(self, invite_event):
+ def get_public_keys(self, invite_event: EventBase) -> List[Dict[str, Any]]:
return event_auth.get_public_keys(invite_event)
async def get_user_by_req(
@@ -170,7 +166,7 @@ class Auth:
allow_guest: bool = False,
rights: str = "access",
allow_expired: bool = False,
- ) -> synapse.types.Requester:
+ ) -> Requester:
"""Get a registered user's ID.
Args:
@@ -196,7 +192,7 @@ class Auth:
access_token = self.get_access_token_from_request(request)
user_id, app_service = await self._get_appservice_user_id(request)
- if user_id:
+ if user_id and app_service:
if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip(
user_id=user_id,
@@ -206,9 +202,7 @@ class Auth:
device_id="dummy-device", # stubbed
)
- requester = synapse.types.create_requester(
- user_id, app_service=app_service
- )
+ requester = create_requester(user_id, app_service=app_service)
request.requester = user_id
opentracing.set_tag("authenticated_entity", user_id)
@@ -251,7 +245,7 @@ class Auth:
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
)
- requester = synapse.types.create_requester(
+ requester = create_requester(
user_info.user_id,
token_id,
is_guest,
@@ -271,7 +265,9 @@ class Auth:
except KeyError:
raise MissingClientTokenError()
- async def _get_appservice_user_id(self, request):
+ async def _get_appservice_user_id(
+ self, request: Request
+ ) -> Tuple[Optional[str], Optional[ApplicationService]]:
app_service = self.store.get_app_service_by_token(
self.get_access_token_from_request(request)
)
@@ -283,6 +279,9 @@ class Auth:
if ip_address not in app_service.ip_range_whitelist:
return None, None
+ # This will always be set by the time Twisted calls us.
+ assert request.args is not None
+
if b"user_id" not in request.args:
return app_service.sender, app_service
@@ -387,7 +386,9 @@ class Auth:
logger.warning("Invalid macaroon in auth: %s %s", type(e), e)
raise InvalidClientTokenError("Invalid macaroon passed.")
- def _parse_and_validate_macaroon(self, token, rights="access"):
+ def _parse_and_validate_macaroon(
+ self, token: str, rights: str = "access"
+ ) -> Tuple[str, bool]:
"""Takes a macaroon and tries to parse and validate it. This is cached
if and only if rights == access and there isn't an expiry.
@@ -432,15 +433,16 @@ class Auth:
return user_id, guest
- def validate_macaroon(self, macaroon, type_string, user_id):
+ def validate_macaroon(
+ self, macaroon: pymacaroons.Macaroon, type_string: str, user_id: str
+ ) -> None:
"""
validate that a Macaroon is understood by and was signed by this server.
Args:
- macaroon(pymacaroons.Macaroon): The macaroon to validate
- type_string(str): The kind of token required (e.g. "access",
- "delete_pusher")
- user_id (str): The user_id required
+ macaroon: The macaroon to validate
+ type_string: The kind of token required (e.g. "access", "delete_pusher")
+ user_id: The user_id required
"""
v = pymacaroons.Verifier()
@@ -465,9 +467,7 @@ class Auth:
if not service:
logger.warning("Unrecognised appservice access token.")
raise InvalidClientTokenError()
- request.requester = synapse.types.create_requester(
- service.sender, app_service=service
- )
+ request.requester = create_requester(service.sender, app_service=service)
return service
async def is_server_admin(self, user: UserID) -> bool:
@@ -519,7 +519,7 @@ class Auth:
return auth_ids
- async def check_can_change_room_list(self, room_id: str, user: UserID):
+ async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool:
"""Determine whether the user is allowed to edit the room's entry in the
published room list.
@@ -554,11 +554,11 @@ class Auth:
return user_level >= send_level
@staticmethod
- def has_access_token(request: Request):
+ def has_access_token(request: Request) -> bool:
"""Checks if the request has an access_token.
Returns:
- bool: False if no access_token was given, True otherwise.
+ False if no access_token was given, True otherwise.
"""
# This will always be set by the time Twisted calls us.
assert request.args is not None
@@ -568,13 +568,13 @@ class Auth:
return bool(query_params) or bool(auth_headers)
@staticmethod
- def get_access_token_from_request(request: Request):
+ def get_access_token_from_request(request: Request) -> str:
"""Extracts the access_token from the request.
Args:
request: The http request.
Returns:
- unicode: The access_token
+ The access_token
Raises:
MissingClientTokenError: If there isn't a single access_token in the
request
@@ -649,5 +649,5 @@ class Auth:
% (user_id, room_id),
)
- def check_auth_blocking(self, *args, **kwargs):
- return self._auth_blocking.check_auth_blocking(*args, **kwargs)
+ async def check_auth_blocking(self, *args, **kwargs) -> None:
+ await self._auth_blocking.check_auth_blocking(*args, **kwargs)
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index a8df60cb89..e6bced93d5 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -13,18 +13,21 @@
# limitations under the License.
import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
from synapse.types import Requester
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class AuthBlocking:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self._server_notices_mxid = hs.config.server_notices_mxid
@@ -43,7 +46,7 @@ class AuthBlocking:
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
- ):
+ ) -> None:
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index c831d9f73c..afc2bc8267 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import List, Optional, Set, Tuple
+from typing import Any, Dict, List, Optional, Set, Tuple
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
@@ -688,7 +688,7 @@ def _verify_third_party_invite(event: EventBase, auth_events: StateMap[EventBase
return False
-def get_public_keys(invite_event):
+def get_public_keys(invite_event: EventBase) -> List[Dict[str, Any]]:
public_keys = []
if "public_key" in invite_event.content:
o = {"public_key": invite_event.content["public_key"]}
|