summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/admin/rooms.py48
-rw-r--r--synapse/rest/admin/users.py23
-rw-r--r--synapse/rest/client/v1/login.py25
-rw-r--r--synapse/rest/client/v1/pusher.py15
-rw-r--r--synapse/rest/client/v2_alpha/register.py2
-rw-r--r--synapse/rest/media/v1/_base.py5
-rw-r--r--synapse/rest/media/v1/media_repository.py2
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py6
-rw-r--r--synapse/rest/media/v1/storage_provider.py16
-rw-r--r--synapse/rest/media/v1/upload_resource.py2
10 files changed, 77 insertions, 67 deletions
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 25f89e4685..b902af8028 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 from http import HTTPStatus
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -25,13 +25,17 @@ from synapse.http.servlet import (
     parse_json_object_from_request,
     parse_string,
 )
+from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import (
     admin_patterns,
     assert_requester_is_admin,
     assert_user_is_admin,
 )
 from synapse.storage.databases.main.room import RoomSortOrder
-from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -45,12 +49,14 @@ class ShutdownRoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.room_shutdown_handler = hs.get_room_shutdown_handler()
 
-    async def on_POST(self, request, room_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -86,13 +92,15 @@ class DeleteRoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.room_shutdown_handler = hs.get_room_shutdown_handler()
         self.pagination_handler = hs.get_pagination_handler()
 
-    async def on_POST(self, request, room_id):
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -146,12 +154,12 @@ class ListRoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -236,19 +244,24 @@ class RoomRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         ret = await self.store.get_room_with_stats(room_id)
         if not ret:
             raise NotFoundError("Room not found")
 
-        return 200, ret
+        members = await self.store.get_users_in_room(room_id)
+        ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
+
+        return (200, ret)
 
 
 class RoomMembersRestServlet(RestServlet):
@@ -258,12 +271,14 @@ class RoomMembersRestServlet(RestServlet):
 
     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
 
-    async def on_GET(self, request, room_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)
 
         ret = await self.store.get_room(room_id)
@@ -280,14 +295,16 @@ class JoinRoomAliasServlet(RestServlet):
 
     PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.room_member_handler = hs.get_room_member_handler()
         self.admin_handler = hs.get_admin_handler()
         self.state_handler = hs.get_state_handler()
 
-    async def on_POST(self, request, room_identifier):
+    async def on_POST(
+        self, request: SynapseRequest, room_identifier: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
 
@@ -314,7 +331,6 @@ class JoinRoomAliasServlet(RestServlet):
             handler = self.room_member_handler
             room_alias = RoomAlias.from_string(room_identifier)
             room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
-            room_id = room_id.to_string()
         else:
             raise SynapseError(
                 400, "%s was not legal room ID or room alias" % (room_identifier,)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b0ff5e1ead..6658c2da56 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -42,17 +42,6 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-_GET_PUSHERS_ALLOWED_KEYS = {
-    "app_display_name",
-    "app_id",
-    "data",
-    "device_display_name",
-    "kind",
-    "lang",
-    "profile_tag",
-    "pushkey",
-}
-
 
 class UsersRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
@@ -320,9 +309,9 @@ class UserRestServletV2(RestServlet):
                             data={},
                         )
 
-            if "avatar_url" in body and type(body["avatar_url"]) == str:
+            if "avatar_url" in body and isinstance(body["avatar_url"], str):
                 await self.profile_handler.set_avatar_url(
-                    user_id, requester, body["avatar_url"], True
+                    target_user, requester, body["avatar_url"], True
                 )
 
             ret = await self.admin_handler.get_user(target_user)
@@ -420,6 +409,9 @@ class UserRegisterServlet(RestServlet):
         if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
             raise SynapseError(400, "Invalid user type")
 
+        if "mac" not in body:
+            raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
+
         got_mac = body["mac"]
 
         want_mac_builder = hmac.new(
@@ -767,10 +759,7 @@ class PushersRestServlet(RestServlet):
 
         pushers = await self.store.get_pushers_by_user_id(user_id)
 
-        filtered_pushers = [
-            {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
-            for p in pushers
-        ]
+        filtered_pushers = [p.as_dict() for p in pushers]
 
         return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
 
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d7ae148214..5f4c6703db 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Awaitable, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
@@ -30,6 +30,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.well_known import WellKnownBuilder
 from synapse.types import JsonDict, UserID
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -42,7 +45,7 @@ class LoginRestServlet(RestServlet):
     JWT_TYPE_DEPRECATED = "m.login.jwt"
     APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
 
@@ -105,22 +108,27 @@ class LoginRestServlet(RestServlet):
         return 200, {"flows": flows}
 
     async def on_POST(self, request: SynapseRequest):
-        self._address_ratelimiter.ratelimit(request.getClientIP())
-
         login_submission = parse_json_object_from_request(request)
 
         try:
             if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
                 appservice = self.auth.get_appservice_by_req(request)
+
+                if appservice.is_rate_limited():
+                    self._address_ratelimiter.ratelimit(request.getClientIP())
+
                 result = await self._do_appservice_login(login_submission, appservice)
             elif self.jwt_enabled and (
                 login_submission["type"] == LoginRestServlet.JWT_TYPE
                 or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
             ):
+                self._address_ratelimiter.ratelimit(request.getClientIP())
                 result = await self._do_jwt_login(login_submission)
             elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
+                self._address_ratelimiter.ratelimit(request.getClientIP())
                 result = await self._do_token_login(login_submission)
             else:
+                self._address_ratelimiter.ratelimit(request.getClientIP())
                 result = await self._do_other_login(login_submission)
         except KeyError:
             raise SynapseError(400, "Missing JSON keys.")
@@ -159,7 +167,9 @@ class LoginRestServlet(RestServlet):
         if not appservice.is_interested_in_user(qualified_user_id):
             raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
 
-        return await self._complete_login(qualified_user_id, login_submission)
+        return await self._complete_login(
+            qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+        )
 
     async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
         """Handle non-token/saml/jwt logins
@@ -194,6 +204,7 @@ class LoginRestServlet(RestServlet):
         login_submission: JsonDict,
         callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
         create_non_existent_users: bool = False,
+        ratelimit: bool = True,
     ) -> Dict[str, str]:
         """Called when we've successfully authed the user and now need to
         actually login them in (e.g. create devices). This gets called on
@@ -208,6 +219,7 @@ class LoginRestServlet(RestServlet):
             callback: Callback function to run after login.
             create_non_existent_users: Whether to create the user if they don't
                 exist. Defaults to False.
+            ratelimit: Whether to ratelimit the login request.
 
         Returns:
             result: Dictionary of account information after successful login.
@@ -216,7 +228,8 @@ class LoginRestServlet(RestServlet):
         # Before we actually log them in we check if they've already logged in
         # too often. This happens here rather than before as we don't
         # necessarily know the user before now.
-        self._account_ratelimiter.ratelimit(user_id.lower())
+        if ratelimit:
+            self._account_ratelimiter.ratelimit(user_id.lower())
 
         if create_non_existent_users:
             canonical_uid = await self.auth_handler.check_user_exists(user_id)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 8fe83f321a..89823fcc39 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
 
 logger = logging.getLogger(__name__)
 
-ALLOWED_KEYS = {
-    "app_display_name",
-    "app_id",
-    "data",
-    "device_display_name",
-    "kind",
-    "lang",
-    "profile_tag",
-    "pushkey",
-}
-
 
 class PushersRestServlet(RestServlet):
     PATTERNS = client_patterns("/pushers$", v1=True)
@@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet):
 
         pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
 
-        filtered_pushers = [
-            {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
-        ]
+        filtered_pushers = [p.as_dict() for p in pushers]
 
         return 200, {"pushers": filtered_pushers}
 
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 722d993811..6b5a1b7109 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -451,7 +451,7 @@ class RegisterRestServlet(RestServlet):
 
         # == Normal User Registration == (everyone else)
         if not self._registration_enabled:
-            raise SynapseError(403, "Registration has been disabled")
+            raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
 
         # For regular registration, convert the provided username to lowercase
         # before attempting to register it. This should mean that people who try
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 67aa993f19..47c2b44bff 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -155,6 +155,11 @@ def add_file_headers(request, media_type, file_size, upload_name):
     request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
     request.setHeader(b"Content-Length", b"%d" % (file_size,))
 
+    # Tell web crawlers to not index, archive, or follow links in media. This
+    # should help to prevent things in the media repo from showing up in web
+    # search results.
+    request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
+
 
 # separators as defined in RFC2616. SP and HT are handled separately.
 # see _can_encode_filename_as_token.
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9cac74ebd8..83beb02b05 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -66,7 +66,7 @@ class MediaRepository:
     def __init__(self, hs):
         self.hs = hs
         self.auth = hs.get_auth()
-        self.client = hs.get_http_client()
+        self.client = hs.get_federation_http_client()
         self.clock = hs.get_clock()
         self.server_name = hs.hostname
         self.store = hs.get_datastore()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index dce6c4d168..1082389d9b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -676,7 +676,11 @@ class PreviewUrlResource(DirectServeJsonResource):
             logger.debug("No media removed from url cache")
 
 
-def decode_and_calc_og(body, media_uri, request_encoding=None):
+def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
+    # If there's no body, nothing useful is going to be found.
+    if not body:
+        return {}
+
     from lxml import etree
 
     try:
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 18c9ed48d6..67f67efde7 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import inspect
 import logging
 import os
 import shutil
@@ -21,6 +20,7 @@ from typing import Optional
 
 from synapse.config._base import Config
 from synapse.logging.context import defer_to_thread, run_in_background
+from synapse.util.async_helpers import maybe_awaitable
 
 from ._base import FileInfo, Responder
 from .media_storage import FileResponder
@@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider):
         if self.store_synchronous:
             # store_file is supposed to return an Awaitable, but guard
             # against improper implementations.
-            result = self.backend.store_file(path, file_info)
-            if inspect.isawaitable(result):
-                return await result
+            return await maybe_awaitable(self.backend.store_file(path, file_info))
         else:
             # TODO: Handle errors.
             async def store():
                 try:
-                    result = self.backend.store_file(path, file_info)
-                    if inspect.isawaitable(result):
-                        return await result
+                    return await maybe_awaitable(
+                        self.backend.store_file(path, file_info)
+                    )
                 except Exception:
                     logger.exception("Error storing file")
 
@@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider):
     async def fetch(self, path, file_info):
         # store_file is supposed to return an Awaitable, but guard
         # against improper implementations.
-        result = self.backend.fetch(path, file_info)
-        if inspect.isawaitable(result):
-            return await result
+        return await maybe_awaitable(self.backend.fetch(path, file_info))
 
 
 class FileStorageProviderBackend(StorageProvider):
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index d76f7389e1..42febc9afc 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -44,7 +44,7 @@ class UploadResource(DirectServeJsonResource):
         requester = await self.auth.get_user_by_req(request)
         # TODO: The checks here are a bit late. The content will have
         # already been uploaded to a tmp file at this point
-        content_length = request.getHeader(b"Content-Length").decode("ascii")
+        content_length = request.getHeader("Content-Length")
         if content_length is None:
             raise SynapseError(msg="Request must specify a Content-Length", code=400)
         if int(content_length) > self.max_upload_size: