summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/admin/purge_room_servlet.py15
-rw-r--r--synapse/rest/admin/server_notice_servlet.py23
-rw-r--r--synapse/rest/client/v1/login.py14
-rw-r--r--synapse/rest/media/v1/thumbnailer.py11
4 files changed, 40 insertions, 23 deletions
diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py

index 8b7bb6d44e..49966ee3e0 100644 --- a/synapse/rest/admin/purge_room_servlet.py +++ b/synapse/rest/admin/purge_room_servlet.py
@@ -12,13 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Tuple + from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) +from synapse.http.site import SynapseRequest from synapse.rest.admin import assert_requester_is_admin from synapse.rest.admin._base import admin_patterns +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer class PurgeRoomServlet(RestServlet): @@ -36,16 +43,12 @@ class PurgeRoomServlet(RestServlet): PATTERNS = admin_patterns("/purge_room$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.pagination_handler = hs.get_pagination_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request) diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 375d055445..f495666f4a 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,17 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Optional, Tuple + from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) +from synapse.http.site import SynapseRequest from synapse.rest.admin import assert_requester_is_admin from synapse.rest.admin._base import admin_patterns from synapse.rest.client.transactions import HttpTransactionCache -from synapse.types import UserID +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer class SendServerNoticeServlet(RestServlet): @@ -44,17 +51,13 @@ class SendServerNoticeServlet(RestServlet): } """ - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs) self.snm = hs.get_server_notices_manager() - def register(self, json_resource): + def register(self, json_resource: HttpServer): PATTERN = "/send_server_notice" json_resource.register_paths( "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__ @@ -66,7 +69,9 @@ class SendServerNoticeServlet(RestServlet): self.__class__.__name__, ) - async def on_POST(self, request, txn_id=None): + async def on_POST( + self, request: SynapseRequest, txn_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ("user_id", "content")) @@ -90,7 +95,7 @@ class SendServerNoticeServlet(RestServlet): return 200, {"event_id": event.event_id} - def on_PUT(self, request, txn_id): + def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]: return self.txns.fetch_or_execute_request( request, self.on_POST, request, txn_id ) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 925edfc402..34bc1bd49b 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py
@@ -219,6 +219,7 @@ class LoginRestServlet(RestServlet): callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, create_non_existent_users: bool = False, ratelimit: bool = True, + auth_provider_id: Optional[str] = None, ) -> Dict[str, str]: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -234,6 +235,8 @@ class LoginRestServlet(RestServlet): create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. ratelimit: Whether to ratelimit the login request. + auth_provider_id: The SSO IdP the user used, if any (just used for the + prometheus metrics). Returns: result: Dictionary of account information after successful login. @@ -256,7 +259,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name + user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id ) result = { @@ -283,12 +286,13 @@ class LoginRestServlet(RestServlet): """ token = login_submission["token"] auth_handler = self.auth_handler - user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( - token - ) + res = await auth_handler.validate_short_term_login_token(token) return await self._complete_login( - user_id, login_submission, self.auth_handler._sso_login_callback + res.user_id, + login_submission, + self.auth_handler._sso_login_callback, + auth_provider_id=res.auth_provider_id, ) async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 07903e4017..988f52c78f 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py
@@ -96,9 +96,14 @@ class Thumbnailer: def _resize(self, width: int, height: int) -> Image: # 1-bit or 8-bit color palette images need converting to RGB # otherwise they will be scaled using nearest neighbour which - # looks awful - if self.image.mode in ["1", "P"]: - self.image = self.image.convert("RGB") + # looks awful. + # + # If the image has transparency, use RGBA instead. + if self.image.mode in ["1", "L", "P"]: + mode = "RGB" + if self.image.info.get("transparency", None) is not None: + mode = "RGBA" + self.image = self.image.convert(mode) return self.image.resize((width, height), Image.ANTIALIAS) def scale(self, width: int, height: int, output_type: str) -> BytesIO: