diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
index 8b7bb6d44e..49966ee3e0 100644
--- a/synapse/rest/admin/purge_room_servlet.py
+++ b/synapse/rest/admin/purge_room_servlet.py
@@ -12,13 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, Tuple
+
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class PurgeRoomServlet(RestServlet):
@@ -36,16 +43,12 @@ class PurgeRoomServlet(RestServlet):
PATTERNS = admin_patterns("/purge_room$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.pagination_handler = hs.get_pagination_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 375d055445..f495666f4a 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,17 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, Optional, Tuple
+
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class SendServerNoticeServlet(RestServlet):
@@ -44,17 +51,13 @@ class SendServerNoticeServlet(RestServlet):
}
"""
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)
self.snm = hs.get_server_notices_manager()
- def register(self, json_resource):
+ def register(self, json_resource: HttpServer):
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
@@ -66,7 +69,9 @@ class SendServerNoticeServlet(RestServlet):
self.__class__.__name__,
)
- async def on_POST(self, request, txn_id=None):
+ async def on_POST(
+ self, request: SynapseRequest, txn_id: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content"))
@@ -90,7 +95,7 @@ class SendServerNoticeServlet(RestServlet):
return 200, {"event_id": event.event_id}
- def on_PUT(self, request, txn_id):
+ def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, txn_id
)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 925edfc402..34bc1bd49b 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -219,6 +219,7 @@ class LoginRestServlet(RestServlet):
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
ratelimit: bool = True,
+ auth_provider_id: Optional[str] = None,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -234,6 +235,8 @@ class LoginRestServlet(RestServlet):
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
ratelimit: Whether to ratelimit the login request.
+ auth_provider_id: The SSO IdP the user used, if any (just used for the
+ prometheus metrics).
Returns:
result: Dictionary of account information after successful login.
@@ -256,7 +259,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name
+ user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
)
result = {
@@ -283,12 +286,13 @@ class LoginRestServlet(RestServlet):
"""
token = login_submission["token"]
auth_handler = self.auth_handler
- user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
- token
- )
+ res = await auth_handler.validate_short_term_login_token(token)
return await self._complete_login(
- user_id, login_submission, self.auth_handler._sso_login_callback
+ res.user_id,
+ login_submission,
+ self.auth_handler._sso_login_callback,
+ auth_provider_id=res.auth_provider_id,
)
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 07903e4017..988f52c78f 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -96,9 +96,14 @@ class Thumbnailer:
def _resize(self, width: int, height: int) -> Image:
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
- # looks awful
- if self.image.mode in ["1", "P"]:
- self.image = self.image.convert("RGB")
+ # looks awful.
+ #
+ # If the image has transparency, use RGBA instead.
+ if self.image.mode in ["1", "L", "P"]:
+ mode = "RGB"
+ if self.image.info.get("transparency", None) is not None:
+ mode = "RGBA"
+ self.image = self.image.convert(mode)
return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width: int, height: int, output_type: str) -> BytesIO:
|