diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 307f5f9a94..05699714ee 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -62,16 +62,14 @@ class Auth:
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
+ self._account_validity_handler = hs.get_account_validity_handler()
- self.token_cache = LruCache(
+ self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
10000, "token_cache"
- ) # type: LruCache[str, Tuple[str, bool]]
+ )
self._auth_blocking = AuthBlocking(self.hs)
- self._account_validity_enabled = (
- hs.config.account_validity.account_validity_enabled
- )
self._track_appservice_user_ips = hs.config.track_appservice_user_ips
self._macaroon_secret_key = hs.config.macaroon_secret_key
self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
@@ -187,12 +185,17 @@ class Auth:
shadow_banned = user_info.shadow_banned
# Deny the request if the user account has expired.
- if self._account_validity_enabled and not allow_expired:
- if await self.store.is_account_expired(
- user_info.user_id, self.clock.time_msec()
+ if not allow_expired:
+ if await self._account_validity_handler.is_user_expired(
+ user_info.user_id
):
+ # Raise the error if either an account validity module has determined
+ # the account has expired, or the legacy account validity
+ # implementation is enabled and determined the account has expired
raise AuthError(
- 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
+ 403,
+ "User account has expired",
+ errcode=Codes.EXPIRED_ACCOUNT,
)
device_id = user_info.device_id
@@ -240,6 +243,37 @@ class Auth:
except KeyError:
raise MissingClientTokenError()
+ async def validate_appservice_can_control_user_id(
+ self, app_service: ApplicationService, user_id: str
+ ):
+ """Validates that the app service is allowed to control
+ the given user.
+
+ Args:
+ app_service: The app service that controls the user
+ user_id: The author MXID that the app service is controlling
+
+ Raises:
+ AuthError: If the application service is not allowed to control the user
+ (user namespace regex does not match, wrong homeserver, etc)
+ or if the user has not been registered yet.
+ """
+
+ # It's ok if the app service is trying to use the sender from their registration
+ if app_service.sender == user_id:
+ pass
+ # Check to make sure the app service is allowed to control the user
+ elif not app_service.is_interested_in_user(user_id):
+ raise AuthError(
+ 403,
+ "Application service cannot masquerade as this user (%s)." % user_id,
+ )
+ # Check to make sure the user is already registered on the homeserver
+ elif not (await self.store.get_user_by_id(user_id)):
+ raise AuthError(
+ 403, "Application service has not registered this user (%s)" % user_id
+ )
+
async def _get_appservice_user_id(
self, request: Request
) -> Tuple[Optional[str], Optional[ApplicationService]]:
@@ -261,13 +295,11 @@ class Auth:
return app_service.sender, app_service
user_id = request.args[b"user_id"][0].decode("utf8")
+ await self.validate_appservice_can_control_user_id(app_service, user_id)
+
if app_service.sender == user_id:
return app_service.sender, app_service
- if not app_service.is_interested_in_user(user_id):
- raise AuthError(403, "Application service cannot masquerade as this user.")
- if not (await self.store.get_user_by_id(user_id)):
- raise AuthError(403, "Application service has not registered this user")
return user_id, app_service
async def get_user_by_access_token(
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index 4cb8bbaf70..054ab14ab6 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -118,7 +118,7 @@ class RedirectException(CodeMessageException):
super().__init__(code=http_code, msg=msg)
self.location = location
- self.cookies = [] # type: List[bytes]
+ self.cookies: List[bytes] = []
class SynapseError(CodeMessageException):
@@ -160,7 +160,7 @@ class ProxiedRequestError(SynapseError):
):
super().__init__(code, msg, errcode)
if additional_fields is None:
- self._additional_fields = {} # type: Dict
+ self._additional_fields: Dict = {}
else:
self._additional_fields = dict(additional_fields)
diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py
index ce49a0ad58..ad1ff6a9df 100644
--- a/synapse/api/filtering.py
+++ b/synapse/api/filtering.py
@@ -289,7 +289,7 @@ class Filter:
room_id = None
ev_type = "m.presence"
contains_url = False
- labels = [] # type: List[str]
+ labels: List[str] = []
else:
sender = event.get("sender", None)
if not sender:
diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py
index b9a10283f4..3e3d09bbd2 100644
--- a/synapse/api/ratelimiting.py
+++ b/synapse/api/ratelimiting.py
@@ -46,9 +46,7 @@ class Ratelimiter:
# * How many times an action has occurred since a point in time
# * The point in time
# * The rate_hz of this particular entry. This can vary per request
- self.actions = (
- OrderedDict()
- ) # type: OrderedDict[Hashable, Tuple[float, int, float]]
+ self.actions: OrderedDict[Hashable, Tuple[float, int, float]] = OrderedDict()
async def can_do_action(
self,
diff --git a/synapse/api/room_versions.py b/synapse/api/room_versions.py
index f6c1c97b40..a20abc5a65 100644
--- a/synapse/api/room_versions.py
+++ b/synapse/api/room_versions.py
@@ -195,7 +195,7 @@ class RoomVersions:
)
-KNOWN_ROOM_VERSIONS = {
+KNOWN_ROOM_VERSIONS: Dict[str, RoomVersion] = {
v.identifier: v
for v in (
RoomVersions.V1,
@@ -209,4 +209,4 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V7,
)
# Note that we do not include MSC2043 here unless it is enabled in the config.
-} # type: Dict[str, RoomVersion]
+}
|