diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 03a215ab1b..f17fda6315 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -23,7 +23,8 @@ from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes
from synapse.types import UserID
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
+from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -39,6 +40,10 @@ AuthEventTypes = (
GUEST_DEVICE_ID = "guest_device"
+class _InvalidMacaroonException(Exception):
+ pass
+
+
class Auth(object):
"""
FIXME: This class contains a mix of functions for authenticating users
@@ -51,6 +56,9 @@ class Auth(object):
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
+ self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000)
+ register_cache("token_cache", self.token_cache)
+
@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
auth_events_ids = yield self.compute_auth_events(
@@ -144,17 +152,8 @@ class Auth(object):
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
with Measure(self.clock, "check_host_in_room"):
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
-
- logger.debug("calling resolve_state_groups from check_host_in_room")
- entry = yield self.state.resolve_state_groups(
- room_id, latest_event_ids
- )
-
- ret = yield self.store.is_host_joined(
- room_id, host, entry.state_group, entry.state
- )
- defer.returnValue(ret)
+ latest_event_ids = yield self.store.is_host_joined(room_id, host)
+ defer.returnValue(latest_event_ids)
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
@@ -205,13 +204,12 @@ class Auth(object):
ip_addr = self.hs.get_ip_from_request(request)
user_agent = request.requestHeaders.getRawHeaders(
- "User-Agent",
- default=[""]
+ b"User-Agent",
+ default=[b""]
)[0]
if user and access_token and ip_addr:
- preserve_context_over_fn(
- self.store.insert_client_ip,
- user=user,
+ self.store.insert_client_ip(
+ user_id=user.to_string(),
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
@@ -272,13 +270,17 @@ class Auth(object):
rights (str): The operation being performed; the access token must
allow this.
Returns:
- dict : dict that includes the user and the ID of their access token.
+ Deferred[dict]: dict that includes:
+ `user` (UserID)
+ `is_guest` (bool)
+ `token_id` (int|None): access token id. May be None if guest
+ `device_id` (str|None): device corresponding to access token
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
try:
- macaroon = pymacaroons.Macaroon.deserialize(token)
- except Exception: # deserialize can throw more-or-less anything
+ user_id, guest = self._parse_and_validate_macaroon(token, rights)
+ except _InvalidMacaroonException:
# doesn't look like a macaroon: treat it as an opaque token which
# must be in the database.
# TODO: it would be nice to get rid of this, but apparently some
@@ -287,19 +289,8 @@ class Auth(object):
defer.returnValue(r)
try:
- user_id = self.get_user_id_from_macaroon(macaroon)
user = UserID.from_string(user_id)
- self.validate_macaroon(
- macaroon, rights, self.hs.config.expire_access_token,
- user_id=user_id,
- )
-
- guest = False
- for caveat in macaroon.caveats:
- if caveat.caveat_id == "guest = true":
- guest = True
-
if guest:
# Guest access tokens are not stored in the database (there can
# only be one access token per guest, anyway).
@@ -371,6 +362,55 @@ class Auth(object):
errcode=Codes.UNKNOWN_TOKEN
)
+ def _parse_and_validate_macaroon(self, token, rights="access"):
+ """Takes a macaroon and tries to parse and validate it. This is cached
+ if and only if rights == access and there isn't an expiry.
+
+ On invalid macaroon raises _InvalidMacaroonException
+
+ Returns:
+ (user_id, is_guest)
+ """
+ if rights == "access":
+ cached = self.token_cache.get(token, None)
+ if cached:
+ return cached
+
+ try:
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+ except Exception: # deserialize can throw more-or-less anything
+ # doesn't look like a macaroon: treat it as an opaque token which
+ # must be in the database.
+ # TODO: it would be nice to get rid of this, but apparently some
+ # people use access tokens which aren't macaroons
+ raise _InvalidMacaroonException()
+
+ try:
+ user_id = self.get_user_id_from_macaroon(macaroon)
+
+ has_expiry = False
+ guest = False
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith("time "):
+ has_expiry = True
+ elif caveat.caveat_id == "guest = true":
+ guest = True
+
+ self.validate_macaroon(
+ macaroon, rights, self.hs.config.expire_access_token,
+ user_id=user_id,
+ )
+ except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+ raise AuthError(
+ self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
+ errcode=Codes.UNKNOWN_TOKEN
+ )
+
+ if not has_expiry and rights == "access":
+ self.token_cache[token] = (user_id, guest)
+
+ return user_id, guest
+
def get_user_id_from_macaroon(self, macaroon):
"""Retrieve the user_id given by the caveats on the macaroon.
@@ -483,6 +523,14 @@ class Auth(object):
)
def is_server_admin(self, user):
+ """ Check if the given user is a local server admin.
+
+ Args:
+ user (str): mxid of user to check
+
+ Returns:
+ bool: True if the user is an admin
+ """
return self.store.is_server_admin(user)
@defer.inlineCallbacks
@@ -624,7 +672,7 @@ def has_access_token(request):
bool: False if no access_token was given, True otherwise.
"""
query_params = request.args.get("access_token")
- auth_headers = request.requestHeaders.getRawHeaders("Authorization")
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
@@ -644,8 +692,8 @@ def get_access_token_from_request(request, token_not_found_http_status=401):
AuthError: If there isn't an access_token in the request.
"""
- auth_headers = request.requestHeaders.getRawHeaders("Authorization")
- query_params = request.args.get("access_token")
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+ query_params = request.args.get(b"access_token")
if auth_headers:
# Try the get the access_token from a "Authorization: Bearer"
# header
|