summary refs log tree commit diff
path: root/synapse/api
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/api')
-rw-r--r--synapse/api/auth.py123
-rw-r--r--synapse/api/auth_blocking.py33
-rw-r--r--synapse/api/constants.py5
3 files changed, 84 insertions, 77 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 1071a0576e..bfcaf68b2a 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -33,8 +33,8 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
 from synapse.logging import opentracing as opentracing
+from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import StateMap, UserID
-from synapse.util.caches import register_cache
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.metrics import Measure
 
@@ -70,8 +70,9 @@ class Auth:
         self.store = hs.get_datastore()
         self.state = hs.get_state_handler()
 
-        self.token_cache = LruCache(10000)
-        register_cache("cache", "token_cache", self.token_cache)
+        self.token_cache = LruCache(
+            10000, "token_cache"
+        )  # type: LruCache[str, Tuple[str, bool]]
 
         self._auth_blocking = AuthBlocking(self.hs)
 
@@ -184,18 +185,12 @@ class Auth:
         """
         try:
             ip_addr = self.hs.get_ip_from_request(request)
-            user_agent = request.requestHeaders.getRawHeaders(
-                b"User-Agent", default=[b""]
-            )[0].decode("ascii", "surrogateescape")
+            user_agent = request.get_user_agent("")
 
             access_token = self.get_access_token_from_request(request)
 
             user_id, app_service = await self._get_appservice_user_id(request)
             if user_id:
-                request.authenticated_entity = user_id
-                opentracing.set_tag("authenticated_entity", user_id)
-                opentracing.set_tag("appservice_id", app_service.id)
-
                 if ip_addr and self._track_appservice_user_ips:
                     await self.store.insert_client_ip(
                         user_id=user_id,
@@ -205,31 +200,38 @@ class Auth:
                         device_id="dummy-device",  # stubbed
                     )
 
-                return synapse.types.create_requester(user_id, app_service=app_service)
+                requester = synapse.types.create_requester(
+                    user_id, app_service=app_service
+                )
+
+                request.requester = user_id
+                opentracing.set_tag("authenticated_entity", user_id)
+                opentracing.set_tag("user_id", user_id)
+                opentracing.set_tag("appservice_id", app_service.id)
+
+                return requester
 
             user_info = await self.get_user_by_access_token(
                 access_token, rights, allow_expired=allow_expired
             )
-            user = user_info["user"]
-            token_id = user_info["token_id"]
-            is_guest = user_info["is_guest"]
-            shadow_banned = user_info["shadow_banned"]
+            token_id = user_info.token_id
+            is_guest = user_info.is_guest
+            shadow_banned = user_info.shadow_banned
 
             # Deny the request if the user account has expired.
             if self._account_validity.enabled and not allow_expired:
-                user_id = user.to_string()
-                if await self.store.is_account_expired(user_id, self.clock.time_msec()):
+                if await self.store.is_account_expired(
+                    user_info.user_id, self.clock.time_msec()
+                ):
                     raise AuthError(
                         403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
                     )
 
-            # device_id may not be present if get_user_by_access_token has been
-            # stubbed out.
-            device_id = user_info.get("device_id")
+            device_id = user_info.device_id
 
-            if user and access_token and ip_addr:
+            if access_token and ip_addr:
                 await self.store.insert_client_ip(
-                    user_id=user.to_string(),
+                    user_id=user_info.token_owner,
                     access_token=access_token,
                     ip=ip_addr,
                     user_agent=user_agent,
@@ -243,19 +245,23 @@ class Auth:
                     errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                 )
 
-            request.authenticated_entity = user.to_string()
-            opentracing.set_tag("authenticated_entity", user.to_string())
-            if device_id:
-                opentracing.set_tag("device_id", device_id)
-
-            return synapse.types.create_requester(
-                user,
+            requester = synapse.types.create_requester(
+                user_info.user_id,
                 token_id,
                 is_guest,
                 shadow_banned,
                 device_id,
                 app_service=app_service,
+                authenticated_entity=user_info.token_owner,
             )
+
+            request.requester = requester
+            opentracing.set_tag("authenticated_entity", user_info.token_owner)
+            opentracing.set_tag("user_id", user_info.user_id)
+            if device_id:
+                opentracing.set_tag("device_id", device_id)
+
+            return requester
         except KeyError:
             raise MissingClientTokenError()
 
@@ -286,7 +292,7 @@ class Auth:
 
     async def get_user_by_access_token(
         self, token: str, rights: str = "access", allow_expired: bool = False,
-    ) -> dict:
+    ) -> TokenLookupResult:
         """ Validate access token and get user_id from it
 
         Args:
@@ -295,13 +301,7 @@ class Auth:
                 allow this
             allow_expired: If False, raises an InvalidClientTokenError
                 if the token is expired
-        Returns:
-            dict that includes:
-               `user` (UserID)
-               `is_guest` (bool)
-               `shadow_banned` (bool)
-               `token_id` (int|None): access token id. May be None if guest
-               `device_id` (str|None): device corresponding to access token
+
         Raises:
             InvalidClientTokenError if a user by that token exists, but the token is
                 expired
@@ -311,9 +311,9 @@ class Auth:
 
         if rights == "access":
             # first look in the database
-            r = await self._look_up_user_by_access_token(token)
+            r = await self.store.get_user_by_access_token(token)
             if r:
-                valid_until_ms = r["valid_until_ms"]
+                valid_until_ms = r.valid_until_ms
                 if (
                     not allow_expired
                     and valid_until_ms is not None
@@ -330,7 +330,6 @@ class Auth:
         # otherwise it needs to be a valid macaroon
         try:
             user_id, guest = self._parse_and_validate_macaroon(token, rights)
-            user = UserID.from_string(user_id)
 
             if rights == "access":
                 if not guest:
@@ -356,23 +355,17 @@ class Auth:
                     raise InvalidClientTokenError(
                         "Guest access token used for regular user"
                     )
-                ret = {
-                    "user": user,
-                    "is_guest": True,
-                    "shadow_banned": False,
-                    "token_id": None,
+
+                ret = TokenLookupResult(
+                    user_id=user_id,
+                    is_guest=True,
                     # all guests get the same device id
-                    "device_id": GUEST_DEVICE_ID,
-                }
+                    device_id=GUEST_DEVICE_ID,
+                )
             elif rights == "delete_pusher":
                 # We don't store these tokens in the database
-                ret = {
-                    "user": user,
-                    "is_guest": False,
-                    "shadow_banned": False,
-                    "token_id": None,
-                    "device_id": None,
-                }
+
+                ret = TokenLookupResult(user_id=user_id, is_guest=False)
             else:
                 raise RuntimeError("Unknown rights setting %s", rights)
             return ret
@@ -481,31 +474,15 @@ class Auth:
         now = self.hs.get_clock().time_msec()
         return now < expiry
 
-    async def _look_up_user_by_access_token(self, token):
-        ret = await self.store.get_user_by_access_token(token)
-        if not ret:
-            return None
-
-        # we use ret.get() below because *lots* of unit tests stub out
-        # get_user_by_access_token in a way where it only returns a couple of
-        # the fields.
-        user_info = {
-            "user": UserID.from_string(ret.get("name")),
-            "token_id": ret.get("token_id", None),
-            "is_guest": False,
-            "shadow_banned": ret.get("shadow_banned"),
-            "device_id": ret.get("device_id"),
-            "valid_until_ms": ret.get("valid_until_ms"),
-        }
-        return user_info
-
     def get_appservice_by_req(self, request):
         token = self.get_access_token_from_request(request)
         service = self.store.get_app_service_by_token(token)
         if not service:
             logger.warning("Unrecognised appservice access token.")
             raise InvalidClientTokenError()
-        request.authenticated_entity = service.sender
+        request.requester = synapse.types.create_requester(
+            service.sender, app_service=service
+        )
         return service
 
     async def is_server_admin(self, user: UserID) -> bool:
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index d8fafd7cb8..9c227218e0 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -14,10 +14,12 @@
 # limitations under the License.
 
 import logging
+from typing import Optional
 
 from synapse.api.constants import LimitBlockingTypes, UserTypes
 from synapse.api.errors import Codes, ResourceLimitError
 from synapse.config.server import is_threepid_reserved
+from synapse.types import Requester
 
 logger = logging.getLogger(__name__)
 
@@ -33,24 +35,47 @@ class AuthBlocking:
         self._max_mau_value = hs.config.max_mau_value
         self._limit_usage_by_mau = hs.config.limit_usage_by_mau
         self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
+        self._server_name = hs.hostname
 
-    async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
+    async def check_auth_blocking(
+        self,
+        user_id: Optional[str] = None,
+        threepid: Optional[dict] = None,
+        user_type: Optional[str] = None,
+        requester: Optional[Requester] = None,
+    ):
         """Checks if the user should be rejected for some external reason,
         such as monthly active user limiting or global disable flag
 
         Args:
-            user_id(str|None): If present, checks for presence against existing
+            user_id: If present, checks for presence against existing
                 MAU cohort
 
-            threepid(dict|None): If present, checks for presence against configured
+            threepid: If present, checks for presence against configured
                 reserved threepid. Used in cases where the user is trying register
                 with a MAU blocked server, normally they would be rejected but their
                 threepid is on the reserved list. user_id and
                 threepid should never be set at the same time.
 
-            user_type(str|None): If present, is used to decide whether to check against
+            user_type: If present, is used to decide whether to check against
                 certain blocking reasons like MAU.
+
+            requester: If present, and the authenticated entity is a user, checks for
+                presence against existing MAU cohort. Passing in both a `user_id` and
+                `requester` is an error.
         """
+        if requester and user_id:
+            raise Exception(
+                "Passed in both 'user_id' and 'requester' to 'check_auth_blocking'"
+            )
+
+        if requester:
+            if requester.authenticated_entity.startswith("@"):
+                user_id = requester.authenticated_entity
+            elif requester.authenticated_entity == self._server_name:
+                # We never block the server from doing actions on behalf of
+                # users.
+                return
 
         # Never fail an auth check for the server notices users or support user
         # This can be a problem where event creation is prohibited due to blocking
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 46013cde15..592abd844b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -155,3 +155,8 @@ class EventContentFields:
 class RoomEncryptionAlgorithms:
     MEGOLM_V1_AES_SHA2 = "m.megolm.v1.aes-sha2"
     DEFAULT = MEGOLM_V1_AES_SHA2
+
+
+class AccountDataTypes:
+    DIRECT = "m.direct"
+    IGNORED_USER_LIST = "m.ignored_user_list"