diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 6f7dc06503..8457db1e22 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
+# Copyright 2020, 2021 The Matrix.org Foundation C.I.C.
+
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,11 +38,14 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
+ ForwardExtremitiesRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
MakeRoomAdminRestServlet,
+ RoomEventContextServlet,
RoomMembersRestServlet,
RoomRestServlet,
+ RoomStateRestServlet,
ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -51,6 +56,7 @@ from synapse.rest.admin.users import (
PushersRestServlet,
ResetPasswordRestServlet,
SearchUsersRestServlet,
+ ShadowBanRestServlet,
UserAdminServlet,
UserMediaRestServlet,
UserMembershipRestServlet,
@@ -209,6 +215,7 @@ def register_servlets(hs, http_server):
"""
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
+ RoomStateRestServlet(hs).register(http_server)
RoomRestServlet(hs).register(http_server)
RoomMembersRestServlet(hs).register(http_server)
DeleteRoomRestServlet(hs).register(http_server)
@@ -230,6 +237,9 @@ def register_servlets(hs, http_server):
EventReportsRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
MakeRoomAdminRestServlet(hs).register(http_server)
+ ShadowBanRestServlet(hs).register(http_server)
+ ForwardExtremitiesRestServlet(hs).register(http_server)
+ RoomEventContextServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index d0c86b204a..ebc587aa06 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -22,8 +22,7 @@ logger = logging.getLogger(__name__)
class DeleteGroupAdminRestServlet(RestServlet):
- """Allows deleting of local groups
- """
+ """Allows deleting of local groups"""
PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 8720b1401f..b996862c05 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -119,8 +119,7 @@ class QuarantineMediaByID(RestServlet):
class ProtectMediaByID(RestServlet):
- """Protect local media from being quarantined.
- """
+ """Protect local media from being quarantined."""
PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
@@ -141,8 +140,7 @@ class ProtectMediaByID(RestServlet):
class ListMediaInRoom(RestServlet):
- """Lists all of the media in a given room.
- """
+ """Lists all of the media in a given room."""
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
@@ -180,8 +178,7 @@ class PurgeMediaCacheRestServlet(RestServlet):
class DeleteMediaByID(RestServlet):
- """Delete local media by a given ID. Removes it from this server.
- """
+ """Delete local media by a given ID. Removes it from this server."""
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index ab7cc9102a..1a3a36f6cf 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,9 +15,11 @@
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple
+from urllib import parse as urlparse
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
+from synapse.api.filtering import Filter
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -33,6 +35,7 @@ from synapse.rest.admin._base import (
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
+from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -292,6 +295,45 @@ class RoomMembersRestServlet(RestServlet):
return 200, ret
+class RoomStateRestServlet(RestServlet):
+ """
+ Get full state within a room.
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self._event_serializer = hs.get_event_client_serializer()
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ ret = await self.store.get_room(room_id)
+ if not ret:
+ raise NotFoundError("Room not found")
+
+ event_ids = await self.store.get_current_state_ids(room_id)
+ events = await self.store.get_events(event_ids.values())
+ now = self.clock.time_msec()
+ room_state = await self._event_serializer.serialize_events(
+ events.values(),
+ now,
+ # We don't bother bundling aggregations in when asked for state
+ # events, as clients won't use them.
+ bundle_aggregations=False,
+ )
+ ret = {"state": room_state}
+
+ return 200, ret
+
+
class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
@@ -431,7 +473,18 @@ class MakeRoomAdminRestServlet(RestServlet):
if not admin_users:
raise SynapseError(400, "No local admin user in room")
- admin_user_id = admin_users[-1]
+ admin_user_id = None
+
+ for admin_user in reversed(admin_users):
+ if room_state.get((EventTypes.Member, admin_user)):
+ admin_user_id = admin_user
+ break
+
+ if not admin_user_id:
+ raise SynapseError(
+ 400,
+ "No local admin user in room",
+ )
pl_content = power_levels.content
else:
@@ -440,7 +493,8 @@ class MakeRoomAdminRestServlet(RestServlet):
admin_user_id = create_event.sender
if not self.is_mine_id(admin_user_id):
raise SynapseError(
- 400, "No local admin user in room",
+ 400,
+ "No local admin user in room",
)
# Grant the user power equal to the room admin by attempting to send an
@@ -450,7 +504,8 @@ class MakeRoomAdminRestServlet(RestServlet):
new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id]
fake_requester = create_requester(
- admin_user_id, authenticated_entity=requester.authenticated_entity,
+ admin_user_id,
+ authenticated_entity=requester.authenticated_entity,
)
try:
@@ -499,3 +554,122 @@ class MakeRoomAdminRestServlet(RestServlet):
)
return 200, {}
+
+
+class ForwardExtremitiesRestServlet(RestServlet):
+ """Allows a server admin to get or clear forward extremities.
+
+ Clearing does not require restarting the server.
+
+ Clear forward extremities:
+ DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+
+ Get forward_extremities:
+ GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.store = hs.get_datastore()
+
+ async def resolve_room_id(self, room_identifier: str) -> str:
+ """Resolve to a room ID, if necessary."""
+ if RoomID.is_valid(room_identifier):
+ resolved_room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+ resolved_room_id = room_id.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+ if not resolved_room_id:
+ raise SynapseError(
+ 400, "Unknown room ID or room alias %s" % room_identifier
+ )
+ return resolved_room_id
+
+ async def on_DELETE(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ room_id = await self.resolve_room_id(room_identifier)
+
+ deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
+ return 200, {"deleted": deleted_count}
+
+ async def on_GET(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ room_id = await self.resolve_room_id(room_identifier)
+
+ extremities = await self.store.get_forward_extremities_for_room(room_id)
+ return 200, {"count": len(extremities), "results": extremities}
+
+
+class RoomEventContextServlet(RestServlet):
+ """
+ Provide the context for an event.
+ This API is designed to be used when system administrators wish to look at
+ an abuse report and understand what happened during and immediately prior
+ to this event.
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")
+
+ def __init__(self, hs):
+ super().__init__()
+ self.clock = hs.get_clock()
+ self.room_context_handler = hs.get_room_context_handler()
+ self._event_serializer = hs.get_event_client_serializer()
+ self.auth = hs.get_auth()
+
+ async def on_GET(self, request, room_id, event_id):
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ limit = parse_integer(request, "limit", default=10)
+
+ # picking the API shape for symmetry with /messages
+ filter_str = parse_string(request, b"filter", encoding="utf-8")
+ if filter_str:
+ filter_json = urlparse.unquote(filter_str)
+ event_filter = Filter(
+ json_decoder.decode(filter_json)
+ ) # type: Optional[Filter]
+ else:
+ event_filter = None
+
+ results = await self.room_context_handler.get_event_context(
+ requester,
+ room_id,
+ event_id,
+ limit,
+ event_filter,
+ use_admin_priviledge=True,
+ )
+
+ if not results:
+ raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
+
+ time_now = self.clock.time_msec()
+ results["events_before"] = await self._event_serializer.serialize_events(
+ results["events_before"], time_now
+ )
+ results["event"] = await self._event_serializer.serialize_event(
+ results["event"], time_now
+ )
+ results["events_after"] = await self._event_serializer.serialize_events(
+ results["events_after"], time_now
+ )
+ results["state"] = await self._event_serializer.serialize_events(
+ results["state"], time_now
+ )
+
+ return 200, results
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index f39e3d6d5c..998a0ef671 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -83,17 +83,32 @@ class UsersRestServletV2(RestServlet):
The parameter `deactivated` can be used to include deactivated users.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
+
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter from must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter limit must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
user_id = parse_string(request, "user_id", default=None)
name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True)
@@ -103,7 +118,7 @@ class UsersRestServletV2(RestServlet):
start, limit, user_id, name, guests, deactivated
)
ret = {"users": users, "total": total}
- if len(users) >= limit:
+ if (start + limit) < total:
ret["next_token"] = str(start + len(users))
return 200, ret
@@ -564,7 +579,7 @@ class ResetPasswordRestServlet(RestServlet):
}
Returns:
200 OK with empty object if success otherwise an error.
- """
+ """
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
@@ -737,7 +752,7 @@ class PushersRestServlet(RestServlet):
Returns:
pushers: Dictionary containing pushers information.
- total: Number of pushers in dictonary `pushers`.
+ total: Number of pushers in dictionary `pushers`.
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
@@ -875,3 +890,39 @@ class UserTokenRestServlet(RestServlet):
)
return 200, {"access_token": token}
+
+
+class ShadowBanRestServlet(RestServlet):
+ """An admin API for shadow-banning a user.
+
+ A shadow-banned users receives successful responses to their client-server
+ API requests, but the events are not propagated into rooms.
+
+ Shadow-banning a user should be used as a tool of last resort and may lead
+ to confusing or broken behaviour for the client.
+
+ Example:
+
+ POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
+ {}
+
+ 200 OK
+ {}
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Only local users can be shadow-banned")
+
+ await self.store.set_shadow_banned(UserID.from_string(user_id), True)
+
+ return 200, {}
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index be938df962..6e2fbedd99 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
-from synapse.http.server import finish_request
+from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
@@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet):
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+ self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self.auth = hs.get_auth()
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
+ self._sso_handler = hs.get_sso_handler()
+
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
@@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
- # While its valid for us to advertise this login type generally,
+ sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
+
+ if self._msc2858_enabled:
+ sso_flow["org.matrix.msc2858.identity_providers"] = [
+ _get_auth_flow_dict_for_idp(idp)
+ for idp in self._sso_handler.get_identity_providers().values()
+ ]
+
+ flows.append(sso_flow)
+
+ # While it's valid for us to advertise this login type generally,
# synapse currently only gives out these tokens as part of the
# SSO login flow.
# Generally we don't want to advertise login flows that clients
@@ -297,7 +310,9 @@ class LoginRestServlet(RestServlet):
except jwt.PyJWTError as e:
# A JWT error occurred, return some info back to the client.
raise LoginError(
- 403, "JWT validation failed: %s" % (str(e),), errcode=Codes.FORBIDDEN,
+ 403,
+ "JWT validation failed: %s" % (str(e),),
+ errcode=Codes.FORBIDDEN,
)
user = payload.get("sub", None)
@@ -311,8 +326,22 @@ class LoginRestServlet(RestServlet):
return result
+def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+ """Return an entry for the login flow dict
+
+ Returns an entry suitable for inclusion in "identity_providers" in the
+ response to GET /_matrix/client/r0/login
+ """
+ e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
+ if idp.idp_icon:
+ e["icon"] = idp.idp_icon
+ if idp.idp_brand:
+ e["brand"] = idp.idp_brand
+ return e
+
+
class SsoRedirectServlet(RestServlet):
- PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+ PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
@@ -324,13 +353,33 @@ class SsoRedirectServlet(RestServlet):
if hs.config.oidc_enabled:
hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler()
+ self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+
+ def register(self, http_server: HttpServer) -> None:
+ super().register(http_server)
+ if self._msc2858_enabled:
+ # expose additional endpoint for MSC2858 support
+ http_server.register_paths(
+ "GET",
+ client_patterns(
+ "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
+ releases=(),
+ unstable=True,
+ ),
+ self.on_GET,
+ self.__class__.__name__,
+ )
- async def on_GET(self, request: SynapseRequest):
+ async def on_GET(
+ self, request: SynapseRequest, idp_id: Optional[str] = None
+ ) -> None:
client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding=None
)
sso_url = await self._sso_handler.handle_redirect_request(
- request, client_redirect_url
+ request,
+ client_redirect_url,
+ idp_id,
)
logger.info("Redirecting to %s", sso_url)
request.redirect(sso_url)
diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 85a66458c5..717c5f2b10 100644
--- a/synapse/rest/client/v1/profile.py
+++ b/synapse/rest/client/v1/profile.py
@@ -60,7 +60,9 @@ class ProfileDisplaynameRestServlet(RestServlet):
new_name = content["displayname"]
except Exception:
raise SynapseError(
- code=400, msg="Unable to parse name", errcode=Codes.BAD_JSON,
+ code=400,
+ msg="Unable to parse name",
+ errcode=Codes.BAD_JSON,
)
await self.profile_handler.set_displayname(user, requester, new_name, is_admin)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 89823fcc39..0c148a213d 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -159,7 +159,9 @@ class PushersRemoveRestServlet(RestServlet):
self.notifier.on_new_replication_data()
respond_with_html_bytes(
- request, 200, PushersRemoveRestServlet.SUCCESS_HTML,
+ request,
+ 200,
+ PushersRemoveRestServlet.SUCCESS_HTML,
)
return None
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index f95627ee61..9a1df30c29 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -362,7 +362,9 @@ class PublicRoomListRestServlet(TransactionRestServlet):
parse_and_validate_server_name(server)
except ValueError:
raise SynapseError(
- 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+ 400,
+ "Invalid server name: %s" % (server,),
+ Codes.INVALID_PARAM,
)
try:
@@ -413,7 +415,9 @@ class PublicRoomListRestServlet(TransactionRestServlet):
parse_and_validate_server_name(server)
except ValueError:
raise SynapseError(
- 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+ 400,
+ "Invalid server name: %s" % (server,),
+ Codes.INVALID_PARAM,
)
try:
@@ -650,7 +654,7 @@ class RoomEventContextServlet(RestServlet):
event_filter = None
results = await self.room_context_handler.get_event_context(
- requester.user, room_id, event_id, limit, event_filter
+ requester, room_id, event_id, limit, event_filter
)
if not results:
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 65e68d641b..adf1d39728 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/email/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.datastore = hs.get_datastore()
@@ -103,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
@@ -191,7 +193,10 @@ class PasswordRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
try:
params, session_id = await self.auth_handler.validate_user_via_ui_auth(
- requester, request, body, "modify your account password",
+ requester,
+ request,
+ body,
+ "modify your account password",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth, but
@@ -310,7 +315,10 @@ class DeactivateAccountRestServlet(RestServlet):
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
- requester, request, body, "deactivate your account",
+ requester,
+ request,
+ body,
+ "deactivate your account",
)
result = await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(),
@@ -379,6 +387,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -430,7 +440,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__()
self.store = self.hs.get_datastore()
@@ -458,6 +468,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -695,7 +709,10 @@ class ThreepidAddRestServlet(RestServlet):
assert_valid_client_secret(client_secret)
await self.auth_handler.validate_user_via_ui_auth(
- requester, request, body, "add a third-party identifier to your account",
+ requester,
+ request,
+ body,
+ "add a third-party identifier to your account",
)
validation_session = await self.identity_handler.validate_threepid_session(
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index 314e01dfe4..3d07aadd39 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -83,7 +83,10 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
await self.auth_handler.validate_user_via_ui_auth(
- requester, request, body, "remove device(s) from your account",
+ requester,
+ request,
+ body,
+ "remove device(s) from your account",
)
await self.device_handler.delete_devices(
@@ -129,7 +132,10 @@ class DeviceRestServlet(RestServlet):
raise
await self.auth_handler.validate_user_via_ui_auth(
- requester, request, body, "remove a device from your account",
+ requester,
+ request,
+ body,
+ "remove a device from your account",
)
await self.device_handler.delete_device(requester.user.to_string(), device_id)
@@ -206,7 +212,9 @@ class DehydratedDeviceServlet(RestServlet):
if "device_data" not in submission:
raise errors.SynapseError(
- 400, "device_data missing", errcode=errors.Codes.MISSING_PARAM,
+ 400,
+ "device_data missing",
+ errcode=errors.Codes.MISSING_PARAM,
)
elif not isinstance(submission["device_data"], dict):
raise errors.SynapseError(
@@ -259,11 +267,15 @@ class ClaimDehydratedDeviceServlet(RestServlet):
if "device_id" not in submission:
raise errors.SynapseError(
- 400, "device_id missing", errcode=errors.Codes.MISSING_PARAM,
+ 400,
+ "device_id missing",
+ errcode=errors.Codes.MISSING_PARAM,
)
elif not isinstance(submission["device_id"], str):
raise errors.SynapseError(
- 400, "device_id must be a string", errcode=errors.Codes.INVALID_PARAM,
+ 400,
+ "device_id must be a string",
+ errcode=errors.Codes.INVALID_PARAM,
)
result = await self.device_handler.rehydrate_device(
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 5b5da71815..d3434225cb 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,13 +16,29 @@
import logging
from functools import wraps
-
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import GroupID
+from typing import TYPE_CHECKING, Optional, Tuple
+
+from twisted.web.http import Request
+
+from synapse.api.constants import (
+ MAX_GROUP_CATEGORYID_LENGTH,
+ MAX_GROUP_ROLEID_LENGTH,
+ MAX_GROUPID_LENGTH,
+)
+from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.groups_local import GroupsLocalHandler
+from synapse.http.servlet import (
+ RestServlet,
+ assert_params_in_dict,
+ parse_json_object_from_request,
+)
+from synapse.types import GroupID, JsonDict
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -33,7 +49,7 @@ def _validate_group_id(f):
"""
@wraps(f)
- def wrapper(self, request, group_id, *args, **kwargs):
+ def wrapper(self, request: Request, group_id: str, *args, **kwargs):
if not GroupID.is_valid(group_id):
raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
@@ -43,19 +59,18 @@ def _validate_group_id(f):
class GroupServlet(RestServlet):
- """Get the group profile
- """
+ """Get the group profile"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -66,11 +81,17 @@ class GroupServlet(RestServlet):
return 200, group_description
@_validate_group_id
- async def on_POST(self, request, group_id):
+ async def on_POST(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert_params_in_dict(
+ content, ("name", "avatar_url", "short_description", "long_description")
+ )
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot create group profiles."
await self.groups_handler.update_group_profile(
group_id, requester_user_id, content
)
@@ -79,19 +100,18 @@ class GroupServlet(RestServlet):
class GroupSummaryServlet(RestServlet):
- """Get the full group summary
- """
+ """Get the full group summary"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -116,18 +136,34 @@ class GroupSummaryRoomsCatServlet(RestServlet):
"/rooms/(?P<room_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, category_id, room_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, category_id: Optional[str], room_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ if category_id == "":
+ raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
+
+ if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+ raise SynapseError(
+ 400,
+ "category_id may not be longer than %s characters"
+ % (MAX_GROUP_CATEGORYID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group summaries."
resp = await self.groups_handler.update_group_summary_room(
group_id,
requester_user_id,
@@ -139,10 +175,15 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, category_id, room_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, category_id: str, room_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group profiles."
resp = await self.groups_handler.delete_group_summary_room(
group_id, requester_user_id, room_id=room_id, category_id=category_id
)
@@ -151,21 +192,22 @@ class GroupSummaryRoomsCatServlet(RestServlet):
class GroupCategoryServlet(RestServlet):
- """Get/add/update/delete a group category
- """
+ """Get/add/update/delete a group category"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id, category_id):
+ async def on_GET(
+ self, request: Request, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -176,11 +218,27 @@ class GroupCategoryServlet(RestServlet):
return 200, category
@_validate_group_id
- async def on_PUT(self, request, group_id, category_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ if not category_id:
+ raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
+
+ if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+ raise SynapseError(
+ 400,
+ "category_id may not be longer than %s characters"
+ % (MAX_GROUP_CATEGORYID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group categories."
resp = await self.groups_handler.update_group_category(
group_id, requester_user_id, category_id=category_id, content=content
)
@@ -188,10 +246,15 @@ class GroupCategoryServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, category_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, category_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group categories."
resp = await self.groups_handler.delete_group_category(
group_id, requester_user_id, category_id=category_id
)
@@ -200,19 +263,18 @@ class GroupCategoryServlet(RestServlet):
class GroupCategoriesServlet(RestServlet):
- """Get all group categories
- """
+ """Get all group categories"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -224,19 +286,20 @@ class GroupCategoriesServlet(RestServlet):
class GroupRoleServlet(RestServlet):
- """Get/add/update/delete a group role
- """
+ """Get/add/update/delete a group role"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id, role_id):
+ async def on_GET(
+ self, request: Request, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -247,11 +310,27 @@ class GroupRoleServlet(RestServlet):
return 200, category
@_validate_group_id
- async def on_PUT(self, request, group_id, role_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ if not role_id:
+ raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
+
+ if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+ raise SynapseError(
+ 400,
+ "role_id may not be longer than %s characters"
+ % (MAX_GROUP_ROLEID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group roles."
resp = await self.groups_handler.update_group_role(
group_id, requester_user_id, role_id=role_id, content=content
)
@@ -259,10 +338,15 @@ class GroupRoleServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, role_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, role_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group roles."
resp = await self.groups_handler.delete_group_role(
group_id, requester_user_id, role_id=role_id
)
@@ -271,19 +355,18 @@ class GroupRoleServlet(RestServlet):
class GroupRolesServlet(RestServlet):
- """Get all group roles
- """
+ """Get all group roles"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -308,18 +391,34 @@ class GroupSummaryUsersRoleServlet(RestServlet):
"/users/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, role_id, user_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, role_id: Optional[str], user_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ if role_id == "":
+ raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
+
+ if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+ raise SynapseError(
+ 400,
+ "role_id may not be longer than %s characters"
+ % (MAX_GROUP_ROLEID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group summaries."
resp = await self.groups_handler.update_group_summary_user(
group_id,
requester_user_id,
@@ -331,10 +430,15 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp
@_validate_group_id
- async def on_DELETE(self, request, group_id, role_id, user_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, role_id: str, user_id: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group summaries."
resp = await self.groups_handler.delete_group_summary_user(
group_id, requester_user_id, user_id=user_id, role_id=role_id
)
@@ -343,19 +447,18 @@ class GroupSummaryUsersRoleServlet(RestServlet):
class GroupRoomServlet(RestServlet):
- """Get all rooms in a group
- """
+ """Get all rooms in a group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -367,19 +470,18 @@ class GroupRoomServlet(RestServlet):
class GroupUsersServlet(RestServlet):
- """Get all users in a group
- """
+ """Get all users in a group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -391,19 +493,18 @@ class GroupUsersServlet(RestServlet):
class GroupInvitedUsersServlet(RestServlet):
- """Get users invited to a group
- """
+ """Get users invited to a group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_GET(self, request, group_id):
+ async def on_GET(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -415,23 +516,25 @@ class GroupInvitedUsersServlet(RestServlet):
class GroupSettingJoinPolicyServlet(RestServlet):
- """Set group join policy
- """
+ """Set group join policy"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group join policy."
result = await self.groups_handler.set_group_join_policy(
group_id, requester_user_id, content
)
@@ -440,19 +543,18 @@ class GroupSettingJoinPolicyServlet(RestServlet):
class GroupCreateServlet(RestServlet):
- """Create a group
- """
+ """Create a group"""
PATTERNS = client_patterns("/create_group$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
self.server_name = hs.hostname
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -461,6 +563,19 @@ class GroupCreateServlet(RestServlet):
localpart = content.pop("localpart")
group_id = GroupID(localpart, self.server_name).to_string()
+ if not localpart:
+ raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
+
+ if len(group_id) > MAX_GROUPID_LENGTH:
+ raise SynapseError(
+ 400,
+ "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,),
+ Codes.INVALID_PARAM,
+ )
+
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot create groups."
result = await self.groups_handler.create_group(
group_id, requester_user_id, content
)
@@ -469,25 +584,29 @@ class GroupCreateServlet(RestServlet):
class GroupAdminRoomsServlet(RestServlet):
- """Add a room to the group
- """
+ """Add a room to the group"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, room_id):
+ async def on_PUT(
+ self, request: Request, group_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify rooms in a group."
result = await self.groups_handler.add_room_to_group(
group_id, requester_user_id, room_id, content
)
@@ -495,10 +614,15 @@ class GroupAdminRoomsServlet(RestServlet):
return 200, result
@_validate_group_id
- async def on_DELETE(self, request, group_id, room_id):
+ async def on_DELETE(
+ self, request: Request, group_id: str, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group categories."
result = await self.groups_handler.remove_room_from_group(
group_id, requester_user_id, room_id
)
@@ -507,26 +631,30 @@ class GroupAdminRoomsServlet(RestServlet):
class GroupAdminRoomsConfigServlet(RestServlet):
- """Update the config of a room in a group
- """
+ """Update the config of a room in a group"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)"
"/config/(?P<config_key>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, room_id, config_key):
+ async def on_PUT(
+ self, request: Request, group_id: str, room_id: str, config_key: str
+ ):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot modify group categories."
result = await self.groups_handler.update_room_in_group(
group_id, requester_user_id, room_id, config_key, content
)
@@ -535,14 +663,13 @@ class GroupAdminRoomsConfigServlet(RestServlet):
class GroupAdminUsersInviteServlet(RestServlet):
- """Invite a user to the group
- """
+ """Invite a user to the group"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
@@ -551,12 +678,15 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.is_mine_id = hs.is_mine_id
@_validate_group_id
- async def on_PUT(self, request, group_id, user_id):
+ async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
config = content.get("config", {})
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot invite users to a group."
result = await self.groups_handler.invite(
group_id, user_id, requester_user_id, config
)
@@ -565,25 +695,27 @@ class GroupAdminUsersInviteServlet(RestServlet):
class GroupAdminUsersKickServlet(RestServlet):
- """Kick a user from the group
- """
+ """Kick a user from the group"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id, user_id):
+ async def on_PUT(self, request: Request, group_id, user_id) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot kick users from a group."
result = await self.groups_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content
)
@@ -592,23 +724,25 @@ class GroupAdminUsersKickServlet(RestServlet):
class GroupSelfLeaveServlet(RestServlet):
- """Leave a joined group
- """
+ """Leave a joined group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot leave a group for a users."
result = await self.groups_handler.remove_user_from_group(
group_id, requester_user_id, requester_user_id, content
)
@@ -617,23 +751,25 @@ class GroupSelfLeaveServlet(RestServlet):
class GroupSelfJoinServlet(RestServlet):
- """Attempt to join a group, or knock
- """
+ """Attempt to join a group, or knock"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot join a user to a group."
result = await self.groups_handler.join_group(
group_id, requester_user_id, content
)
@@ -642,23 +778,25 @@ class GroupSelfJoinServlet(RestServlet):
class GroupSelfAcceptInviteServlet(RestServlet):
- """Accept a group invite
- """
+ """Accept a group invite"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
+ assert isinstance(
+ self.groups_handler, GroupsLocalHandler
+ ), "Workers cannot accept an invite to a group."
result = await self.groups_handler.accept_invite(
group_id, requester_user_id, content
)
@@ -667,19 +805,18 @@ class GroupSelfAcceptInviteServlet(RestServlet):
class GroupSelfUpdatePublicityServlet(RestServlet):
- """Update whether we publicise a users membership of a group
- """
+ """Update whether we publicise a users membership of a group"""
PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
@_validate_group_id
- async def on_PUT(self, request, group_id):
+ async def on_PUT(self, request: Request, group_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -691,19 +828,18 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
class PublicisedGroupsForUserServlet(RestServlet):
- """Get the list of groups a user is advertising
- """
+ """Get the list of groups a user is advertising"""
PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request, user_id):
+ async def on_GET(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
result = await self.groups_handler.get_publicised_groups_for_user(user_id)
@@ -712,19 +848,18 @@ class PublicisedGroupsForUserServlet(RestServlet):
class PublicisedGroupsForUsersServlet(RestServlet):
- """Get the list of groups a user is advertising
- """
+ """Get the list of groups a user is advertising"""
PATTERNS = client_patterns("/publicised_groups$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.groups_handler = hs.get_groups_local_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
@@ -736,18 +871,17 @@ class PublicisedGroupsForUsersServlet(RestServlet):
class GroupsForUserServlet(RestServlet):
- """Get all groups the logged in user is joined to
- """
+ """Get all groups the logged in user is joined to"""
PATTERNS = client_patterns("/joined_groups$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -756,7 +890,7 @@ class GroupsForUserServlet(RestServlet):
return 200, result
-def register_servlets(hs, http_server):
+def register_servlets(hs: "HomeServer", http_server):
GroupServlet(hs).register(http_server)
GroupSummaryServlet(hs).register(http_server)
GroupInvitedUsersServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index a6134ead8a..f092e5b3a2 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -271,7 +271,10 @@ class SigningKeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
await self.auth_handler.validate_user_via_ui_auth(
- requester, request, body, "add a device signing key to your account",
+ requester,
+ request,
+ body,
+ "add a device signing key to your account",
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index b093183e79..8f68d8dfc8 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -126,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
@@ -191,6 +193,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
body, ["client_secret", "country", "phone_number", "send_attempt"]
)
client_secret = body["client_secret"]
+ assert_valid_client_secret(client_secret)
country = body["country"]
phone_number = body["phone_number"]
send_attempt = body["send_attempt"]
@@ -205,6 +208,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn
)
@@ -287,6 +294,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
sid = parse_string(request, "sid", required=True)
client_secret = parse_string(request, "client_secret", required=True)
+ assert_valid_client_secret(client_secret)
token = parse_string(request, "token", required=True)
# Attempt to validate a 3PID session
@@ -514,7 +522,10 @@ class RegisterRestServlet(RestServlet):
# not this will raise a user-interactive auth error.
try:
auth_result, params, session_id = await self.auth_handler.check_ui_auth(
- self._registration_flows, request, body, "register a new account",
+ self._registration_flows,
+ request,
+ body,
+ "register a new account",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth.
@@ -657,7 +668,9 @@ class RegisterRestServlet(RestServlet):
username, as_token
)
return await self._create_registration_details(
- user_id, body, is_appservice_ghost=True,
+ user_id,
+ body,
+ is_appservice_ghost=True,
)
async def _create_registration_details(
diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 18c75738f8..fe765da23c 100644
--- a/synapse/rest/client/v2_alpha/relations.py
+++ b/synapse/rest/client/v2_alpha/relations.py
@@ -244,7 +244,9 @@ class RelationAggregationPaginationServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
- room_id, requester.user.to_string(), allow_departed_users=True,
+ room_id,
+ requester.user.to_string(),
+ allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
@@ -322,7 +324,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
await self.auth.check_user_in_room_or_world_readable(
- room_id, requester.user.to_string(), allow_departed_users=True,
+ room_id,
+ requester.user.to_string(),
+ allow_departed_users=True,
)
# This checks that a) the event exists and b) the user is allowed to
diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index bf030e0ff4..147920767f 100644
--- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
+++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class RoomUpgradeRestServlet(RestServlet):
- """Handler for room uprade requests.
+ """Handler for room upgrade requests.
Handles requests of the form:
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index b3e4d5612e..8b9ef26cf2 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -100,6 +100,7 @@ class ConsentResource(DirectServeHtmlResource):
consent_template_directory = hs.config.user_consent_template_dir
+ # TODO: switch to synapse.util.templates.build_jinja_env
loader = jinja2.FileSystemLoader(consent_template_directory)
self._jinja_env = jinja2.Environment(
loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"])
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 31a41e4a27..90bbeca679 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -137,7 +137,7 @@ def add_file_headers(
# section 3.6 [2] to be a `token` or a `quoted-string`, where a `token`
# is (essentially) a single US-ASCII word, and a `quoted-string` is a
# US-ASCII string surrounded by double-quotes, using backslash as an
- # escape charater. Note that %-encoding is *not* permitted.
+ # escape character. Note that %-encoding is *not* permitted.
#
# `filename*` is defined to be an `ext-value`, which is defined in
# RFC5987 section 3.2.1 [3] to be `charset "'" [ language ] "'" value-chars`,
@@ -300,6 +300,7 @@ class FileInfo:
thumbnail_height (int)
thumbnail_method (str)
thumbnail_type (str): Content type of thumbnail, e.g. image/png
+ thumbnail_length (int): The size of the media file, in bytes.
"""
def __init__(
@@ -312,6 +313,7 @@ class FileInfo:
thumbnail_height=None,
thumbnail_method=None,
thumbnail_type=None,
+ thumbnail_length=None,
):
self.server_name = server_name
self.file_id = file_id
@@ -321,6 +323,7 @@ class FileInfo:
self.thumbnail_height = thumbnail_height
self.thumbnail_method = thumbnail_method
self.thumbnail_type = thumbnail_type
+ self.thumbnail_length = thumbnail_length
def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 3ed219ae43..48f4433155 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -51,7 +51,8 @@ class DownloadResource(DirectServeJsonResource):
b" object-src 'self';",
)
request.setHeader(
- b"Referrer-Policy", b"no-referrer",
+ b"Referrer-Policy",
+ b"no-referrer",
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 4c9946a616..a0162d4255 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -184,7 +184,7 @@ class MediaRepository:
async def get_local_media(
self, request: Request, media_id: str, name: Optional[str]
) -> None:
- """Responds to reqests for local media, if exists, or returns 404.
+ """Responds to requests for local media, if exists, or returns 404.
Args:
request: The incoming request.
@@ -306,7 +306,7 @@ class MediaRepository:
media_info = await self.store.get_cached_remote_media(server_name, media_id)
# file_id is the ID we use to track the file locally. If we've already
- # seen the file then reuse the existing ID, otherwise genereate a new
+ # seen the file then reuse the existing ID, otherwise generate a new
# one.
# If we have an entry in the DB, try and look for it
@@ -325,7 +325,10 @@ class MediaRepository:
# Failed to find the file anywhere, lets download it.
try:
- media_info = await self._download_remote_file(server_name, media_id,)
+ media_info = await self._download_remote_file(
+ server_name,
+ media_id,
+ )
except SynapseError:
raise
except Exception as e:
@@ -351,7 +354,11 @@ class MediaRepository:
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
- async def _download_remote_file(self, server_name: str, media_id: str,) -> dict:
+ async def _download_remote_file(
+ self,
+ server_name: str,
+ media_id: str,
+ ) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -773,7 +780,11 @@ class MediaRepository:
)
except Exception as e:
thumbnail_exists = await self.store.get_remote_media_thumbnail(
- server_name, media_id, t_width, t_height, t_type,
+ server_name,
+ media_id,
+ t_width,
+ t_height,
+ t_type,
)
if not thumbnail_exists:
raise e
@@ -832,7 +843,10 @@ class MediaRepository:
return await self._remove_local_media_from_disk([media_id])
async def delete_old_local_media(
- self, before_ts: int, size_gt: int = 0, keep_profiles: bool = True,
+ self,
+ before_ts: int,
+ size_gt: int = 0,
+ keep_profiles: bool = True,
) -> Tuple[List[str], int]:
"""
Delete local or remote media from this server by size and timestamp. Removes
@@ -849,7 +863,9 @@ class MediaRepository:
A tuple of (list of deleted media IDs, total deleted media IDs).
"""
old_media = await self.store.get_local_media_before(
- before_ts, size_gt, keep_profiles,
+ before_ts,
+ size_gt,
+ keep_profiles,
)
return await self._remove_local_media_from_disk(old_media)
@@ -927,10 +943,10 @@ class MediaRepositoryResource(Resource):
<thumbnail>
- The thumbnail methods are "crop" and "scale". "scale" trys to return an
+ The thumbnail methods are "crop" and "scale". "scale" tries to return an
image where either the width or the height is smaller than the requested
size. The client should then scale and letterbox the image if it needs to
- fit within a given rectangle. "crop" trys to return an image where the
+ fit within a given rectangle. "crop" tries to return an image where the
width and height are close to the requested size and the aspect matches
the requested size. The client should scale the image if it needs to fit
within a given rectangle.
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 89cdd605aa..1057e638be 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -16,13 +16,17 @@ import contextlib
import logging
import os
import shutil
-from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
+from typing import IO, TYPE_CHECKING, Any, Callable, Optional, Sequence
+
+import attr
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
+from synapse.api.errors import NotFoundError
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
+from synapse.util import Clock
from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import FileInfo, Responder
@@ -58,6 +62,8 @@ class MediaStorage:
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
+ self.spam_checker = hs.get_spam_checker()
+ self.clock = hs.get_clock()
async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
@@ -79,8 +85,7 @@ class MediaStorage:
return fname
async def write_to_file(self, source: IO, output: IO):
- """Asynchronously write the `source` to `output`.
- """
+ """Asynchronously write the `source` to `output`."""
await defer_to_thread(self.reactor, _write_file_synchronously, source, output)
@contextlib.contextmanager
@@ -127,18 +132,29 @@ class MediaStorage:
f.flush()
f.close()
+ spam = await self.spam_checker.check_media_file_for_spam(
+ ReadableFileWrapper(self.clock, fname), file_info
+ )
+ if spam:
+ logger.info("Blocking media due to spam checker")
+ # Note that we'll delete the stored media, due to the
+ # try/except below. The media also won't be stored in
+ # the DB.
+ raise SpamMediaException()
+
for provider in self.storage_providers:
await provider.store_file(path, file_info)
finished_called[0] = True
yield f, fname, finish
- except Exception:
+ except Exception as e:
try:
os.remove(fname)
except Exception:
pass
- raise
+
+ raise e from None
if not finished_called:
raise Exception("Finished callback not called")
@@ -302,3 +318,38 @@ class FileResponder(Responder):
def __exit__(self, exc_type, exc_val, exc_tb):
self.open_file.close()
+
+
+class SpamMediaException(NotFoundError):
+ """The media was blocked by a spam checker, so we simply 404 the request (in
+ the same way as if it was quarantined).
+ """
+
+
+@attr.s(slots=True)
+class ReadableFileWrapper:
+ """Wrapper that allows reading a file in chunks, yielding to the reactor,
+ and writing to a callback.
+
+ This is simplified `FileSender` that takes an IO object rather than an
+ `IConsumer`.
+ """
+
+ CHUNK_SIZE = 2 ** 14
+
+ clock = attr.ib(type=Clock)
+ path = attr.ib(type=str)
+
+ async def write_chunks_to(self, callback: Callable[[bytes], None]):
+ """Reads the file in chunks and calls the callback with each chunk."""
+
+ with open(self.path, "rb") as file:
+ while True:
+ chunk = file.read(self.CHUNK_SIZE)
+ if not chunk:
+ break
+
+ callback(chunk)
+
+ # We yield to the reactor by sleeping for 0 seconds.
+ await self.clock.sleep(0)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index a632099167..6104ef4e46 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -58,7 +58,10 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
+_charset_match = re.compile(br'<\s*meta[^>]*charset\s*=\s*"?([a-z0-9-]+)"?', flags=re.I)
+_xml_encoding_match = re.compile(
+ br'\s*<\s*\?\s*xml[^>]*encoding="([a-z0-9-]+)"', flags=re.I
+)
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
OG_TAG_NAME_MAXLEN = 50
@@ -300,24 +303,7 @@ class PreviewUrlResource(DirectServeJsonResource):
with open(media_info["filename"], "rb") as file:
body = file.read()
- encoding = None
-
- # Let's try and figure out if it has an encoding set in a meta tag.
- # Limit it to the first 1kb, since it ought to be in the meta tags
- # at the top.
- match = _charset_match.search(body[:1000])
-
- # If we find a match, it should take precedence over the
- # Content-Type header, so set it here.
- if match:
- encoding = match.group(1).decode("ascii")
-
- # If we don't find a match, we'll look at the HTTP Content-Type, and
- # if that doesn't exist, we'll fall back to UTF-8.
- if not encoding:
- content_match = _content_type_match.match(media_info["media_type"])
- encoding = content_match.group(1) if content_match else "utf-8"
-
+ encoding = get_html_media_encoding(body, media_info["media_type"])
og = decode_and_calc_og(body, media_info["uri"], encoding)
# pre-cache the image for posterity
@@ -386,7 +372,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Check whether the URL should be downloaded as oEmbed content instead.
- Params:
+ Args:
url: The URL to check.
Returns:
@@ -403,7 +389,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Request content from an oEmbed endpoint.
- Params:
+ Args:
endpoint: The oEmbed API endpoint.
url: The URL to pass to the API.
@@ -594,8 +580,7 @@ class PreviewUrlResource(DirectServeJsonResource):
)
async def _expire_url_cache_data(self) -> None:
- """Clean up expired url cache content, media and thumbnails.
- """
+ """Clean up expired url cache content, media and thumbnails."""
# TODO: Delete from backup media store
assert self._worker_run_media_background_jobs
@@ -689,30 +674,101 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
+def get_html_media_encoding(body: bytes, content_type: str) -> str:
+ """
+ Get the encoding of the body based on the (presumably) HTML body or media_type.
+
+ The precedence used for finding a character encoding is:
+
+ 1. meta tag with a charset declared.
+ 2. The XML document's character encoding attribute.
+ 3. The Content-Type header.
+ 4. Fallback to UTF-8.
+
+ Args:
+ body: The HTML document, as bytes.
+ content_type: The Content-Type header.
+
+ Returns:
+ The character encoding of the body, as a string.
+ """
+ # Limit searches to the first 1kb, since it ought to be at the top.
+ body_start = body[:1024]
+
+ # Let's try and figure out if it has an encoding set in a meta tag.
+ match = _charset_match.search(body_start)
+ if match:
+ return match.group(1).decode("ascii")
+
+ # TODO Support <meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
+
+ # If we didn't find a match, see if it an XML document with an encoding.
+ match = _xml_encoding_match.match(body_start)
+ if match:
+ return match.group(1).decode("ascii")
+
+ # If we don't find a match, we'll look at the HTTP Content-Type, and
+ # if that doesn't exist, we'll fall back to UTF-8.
+ content_match = _content_type_match.match(content_type)
+ if content_match:
+ return content_match.group(1)
+
+ return "utf-8"
+
+
def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None
) -> Dict[str, Optional[str]]:
+ """
+ Calculate metadata for an HTML document.
+
+ This uses lxml to parse the HTML document into the OG response. If errors
+ occur during processing of the document, an empty response is returned.
+
+ Args:
+ body: The HTML document, as bytes.
+ media_url: The URI used to download the body.
+ request_encoding: The character encoding of the body, as a string.
+
+ Returns:
+ The OG response as a dictionary.
+ """
# If there's no body, nothing useful is going to be found.
if not body:
return {}
from lxml import etree
+ # Create an HTML parser. If this fails, log and return no metadata.
try:
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body, parser)
- og = _calc_og(tree, media_uri)
+ except LookupError:
+ # blindly consider the encoding as utf-8.
+ parser = etree.HTMLParser(recover=True, encoding="utf-8")
+ except Exception as e:
+ logger.warning("Unable to create HTML parser: %s" % (e,))
+ return {}
+
+ def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
+ # Attempt to parse the body. If this fails, log and return no metadata.
+ tree = etree.fromstring(body_attempt, parser)
+
+ # The data was successfully parsed, but no tree was found.
+ if tree is None:
+ return {}
+
+ return _calc_og(tree, media_uri)
+
+ # Attempt to parse the body. If this fails, log and return no metadata.
+ try:
+ return _attempt_calc_og(body)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
- parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
- og = _calc_og(tree, media_uri)
-
- return og
+ return _attempt_calc_og(body.decode("utf-8", "ignore"))
-def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
+def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index d6880f2e6e..d653a58be9 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,7 +16,7 @@
import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
from twisted.web.http import Request
@@ -106,31 +106,17 @@ class ThumbnailResource(DirectServeJsonResource):
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
-
- if thumbnail_infos:
- thumbnail_info = self._select_thumbnail(
- width, height, method, m_type, thumbnail_infos
- )
-
- file_info = FileInfo(
- server_name=None,
- file_id=media_id,
- url_cache=media_info["url_cache"],
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- )
-
- t_type = file_info.thumbnail_type
- t_length = thumbnail_info["thumbnail_length"]
-
- responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(request, responder, t_type, t_length)
- else:
- logger.info("Couldn't find any generated thumbnails")
- respond_404(request)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_id,
+ url_cache=media_info["url_cache"],
+ server_name=None,
+ )
async def _select_or_generate_local_thumbnail(
self,
@@ -276,26 +262,64 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_info["filesystem_id"],
+ url_cache=None,
+ server_name=server_name,
+ )
+ async def _select_and_respond_with_thumbnail(
+ self,
+ request: Request,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ thumbnail_infos: List[Dict[str, Any]],
+ file_id: str,
+ url_cache: Optional[str] = None,
+ server_name: Optional[str] = None,
+ ) -> None:
+ """
+ Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ request: The incoming request.
+ desired_width: The desired width, the returned thumbnail may be larger than this.
+ desired_height: The desired height, the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: The URL cache value.
+ server_name: The server name, if this is a remote thumbnail.
+ """
if thumbnail_infos:
- thumbnail_info = self._select_thumbnail(
- width, height, method, m_type, thumbnail_infos
+ file_info = self._select_thumbnail(
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ thumbnail_infos,
+ file_id,
+ url_cache,
+ server_name,
)
- file_info = FileInfo(
- server_name=server_name,
- file_id=media_info["filesystem_id"],
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- )
-
- t_type = file_info.thumbnail_type
- t_length = thumbnail_info["thumbnail_length"]
+ if not file_info:
+ logger.info("Couldn't find a thumbnail matching the desired inputs")
+ respond_404(request)
+ return
responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(
+ request, responder, file_info.thumbnail_type, file_info.thumbnail_length
+ )
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)
@@ -306,67 +330,117 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height: int,
desired_method: str,
desired_type: str,
- thumbnail_infos,
- ) -> dict:
+ thumbnail_infos: List[Dict[str, Any]],
+ file_id: str,
+ url_cache: Optional[str],
+ server_name: Optional[str],
+ ) -> Optional[FileInfo]:
+ """
+ Choose an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ desired_width: The desired width, the returned thumbnail may be larger than this.
+ desired_height: The desired height, the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: The URL cache value.
+ server_name: The server name, if this is a remote thumbnail.
+
+ Returns:
+ The thumbnail which best matches the desired parameters.
+ """
+ desired_method = desired_method.lower()
+
+ # The chosen thumbnail.
+ thumbnail_info = None
+
d_w = desired_width
d_h = desired_height
- if desired_method.lower() == "crop":
+ if desired_method == "crop":
+ # Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list = []
+ # Other thumbnails.
crop_info_list2 = []
for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info["thumbnail_method"] != "crop":
+ continue
+
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
- t_method = info["thumbnail_method"]
- if t_method == "crop":
- aspect_quality = abs(d_w * t_h - d_h * t_w)
- min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
- size_quality = abs((d_w - t_w) * (d_h - t_h))
- type_quality = desired_type != info["thumbnail_type"]
- length_quality = info["thumbnail_length"]
- if t_w >= d_w or t_h >= d_h:
- crop_info_list.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
+ aspect_quality = abs(d_w * t_h - d_h * t_w)
+ min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+ size_quality = abs((d_w - t_w) * (d_h - t_h))
+ type_quality = desired_type != info["thumbnail_type"]
+ length_quality = info["thumbnail_length"]
+ if t_w >= d_w or t_h >= d_h:
+ crop_info_list.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
)
- else:
- crop_info_list2.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
+ )
+ else:
+ crop_info_list2.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
)
+ )
if crop_info_list:
- return min(crop_info_list)[-1]
- else:
- return min(crop_info_list2)[-1]
- else:
+ thumbnail_info = min(crop_info_list)[-1]
+ elif crop_info_list2:
+ thumbnail_info = min(crop_info_list2)[-1]
+ elif desired_method == "scale":
+ # Thumbnails that match equal or larger sizes of desired width/height.
info_list = []
+ # Other thumbnails.
info_list2 = []
+
for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info["thumbnail_method"] != "scale":
+ continue
+
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
- t_method = info["thumbnail_method"]
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
- if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
+ if t_w >= d_w or t_h >= d_h:
info_list.append((size_quality, type_quality, length_quality, info))
- elif t_method == "scale":
+ else:
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
if info_list:
- return min(info_list)[-1]
- else:
- return min(info_list2)[-1]
+ thumbnail_info = min(info_list)[-1]
+ elif info_list2:
+ thumbnail_info = min(info_list2)[-1]
+
+ if thumbnail_info:
+ return FileInfo(
+ file_id=file_id,
+ url_cache=url_cache,
+ server_name=server_name,
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
+ thumbnail_length=thumbnail_info["thumbnail_length"],
+ )
+
+ # No matching thumbnail was found.
+ return None
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 6da76ae994..1136277794 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -22,6 +22,7 @@ from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
+from synapse.rest.media.v1.media_storage import SpamMediaException
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@@ -86,9 +87,14 @@ class UploadResource(DirectServeJsonResource):
# disposition = headers.getRawHeaders(b"Content-Disposition")[0]
# TODO(markjh): parse content-dispostion
- content_uri = await self.media_repo.create_content(
- media_type, upload_name, request.content, content_length, requester.user
- )
+ try:
+ content_uri = await self.media_repo.create_content(
+ media_type, upload_name, request.content, content_length, requester.user
+ )
+ except SpamMediaException:
+ # For uploading of media we want to respond with a 400, instead of
+ # the default 404, as that would just be confusing.
+ raise SynapseError(400, "Bad content")
logger.info("Uploaded content with URI %r", content_uri)
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
index c0b733488b..9eeb970580 100644
--- a/synapse/rest/synapse/client/__init__.py
+++ b/synapse/rest/synapse/client/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,3 +12,56 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from typing import TYPE_CHECKING, Mapping
+
+from twisted.web.resource import Resource
+
+from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client.sso_register import SsoRegisterResource
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resource]:
+ """Builds a resource tree to include synapse-specific client resources
+
+ These are resources which should be loaded on all workers which expose a C-S API:
+ ie, the main process, and any generic workers so configured.
+
+ Returns:
+ map from path to Resource.
+ """
+ resources = {
+ # SSO bits. These are always loaded, whether or not SSO login is actually
+ # enabled (they just won't work very well if it's not)
+ "/_synapse/client/pick_idp": PickIdpResource(hs),
+ "/_synapse/client/pick_username": pick_username_resource(hs),
+ "/_synapse/client/new_user_consent": NewUserConsentResource(hs),
+ "/_synapse/client/sso_register": SsoRegisterResource(hs),
+ }
+
+ # provider-specific SSO bits. Only load these if they are enabled, since they
+ # rely on optional dependencies.
+ if hs.config.oidc_enabled:
+ from synapse.rest.synapse.client.oidc import OIDCResource
+
+ resources["/_synapse/client/oidc"] = OIDCResource(hs)
+
+ if hs.config.saml2_enabled:
+ from synapse.rest.synapse.client.saml2 import SAML2Resource
+
+ res = SAML2Resource(hs)
+ resources["/_synapse/client/saml2"] = res
+
+ # This is also mounted under '/_matrix' for backwards-compatibility.
+ # To be removed in Synapse v1.32.0.
+ resources["/_matrix/saml2"] = res
+
+ return resources
+
+
+__all__ = ["build_synapse_client_resource_tree"]
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
new file mode 100644
index 0000000000..b2e0f93810
--- /dev/null
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource, respond_with_html
+from synapse.http.servlet import parse_string
+from synapse.types import UserID
+from synapse.util.templates import build_jinja_env
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class NewUserConsentResource(DirectServeHtmlResource):
+ """A resource which collects consent to the server's terms from a new user
+
+ This resource gets mounted at /_synapse/client/new_user_consent, and is shown
+ when we are automatically creating a new user due to an SSO login.
+
+ It shows a template which prompts the user to go and read the Ts and Cs, and click
+ a clickybox if they have done so.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+ self._server_name = hs.hostname
+ self._consent_version = hs.config.consent.user_consent_version
+
+ def template_search_dirs():
+ if hs.config.sso.sso_template_dir:
+ yield hs.config.sso.sso_template_dir
+ yield hs.config.sso.default_template_dir
+
+ self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+
+ async def _async_render_GET(self, request: Request) -> None:
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ session = self._sso_handler.get_mapping_session(session_id)
+ except SynapseError as e:
+ logger.warning("Error fetching session: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ user_id = UserID(session.chosen_localpart, self._server_name)
+ user_profile = {
+ "display_name": session.display_name,
+ }
+
+ template_params = {
+ "user_id": user_id.to_string(),
+ "user_profile": user_profile,
+ "consent_version": self._consent_version,
+ "terms_url": "/_matrix/consent?v=%s" % (self._consent_version,),
+ }
+
+ template = self._jinja_env.get_template("sso_new_user_consent.html")
+ html = template.render(template_params)
+ respond_with_html(request, 200, html)
+
+ async def _async_render_POST(self, request: Request):
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ except SynapseError as e:
+ logger.warning("Error fetching session cookie: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ try:
+ accepted_version = parse_string(request, "accepted_version", required=True)
+ except SynapseError as e:
+ self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code)
+ return
+
+ await self._sso_handler.handle_terms_accepted(
+ request, session_id, accepted_version
+ )
diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py
index d958dd65bb..64c0deb75d 100644
--- a/synapse/rest/oidc/__init__.py
+++ b/synapse/rest/synapse/client/oidc/__init__.py
@@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
from twisted.web.resource import Resource
-from synapse.rest.oidc.callback_resource import OIDCCallbackResource
+from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
logger = logging.getLogger(__name__)
@@ -25,3 +26,6 @@ class OIDCResource(Resource):
def __init__(self, hs):
Resource.__init__(self)
self.putChild(b"callback", OIDCCallbackResource(hs))
+
+
+__all__ = ["OIDCResource"]
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py
index f7a0bc4bdb..1af33f0a45 100644
--- a/synapse/rest/oidc/callback_resource.py
+++ b/synapse/rest/synapse/client/oidc/callback_resource.py
@@ -12,19 +12,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
+from typing import TYPE_CHECKING
from synapse.http.server import DirectServeHtmlResource
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
class OIDCCallbackResource(DirectServeHtmlResource):
isLeaf = 1
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self._oidc_handler = hs.get_oidc_handler()
async def _async_render_GET(self, request):
await self._oidc_handler.handle_oidc_callback(request)
+
+ async def _async_render_POST(self, request):
+ # the auth response can be returned via an x-www-form-urlencoded form instead
+ # of GET params, as per
+ # https://openid.net/specs/oauth-v2-form-post-response-mode-1_0.html.
+ await self._oidc_handler.handle_oidc_callback(request)
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
index d3b6803e65..96077cfcd1 100644
--- a/synapse/rest/synapse/client/pick_username.py
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -12,42 +12,42 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import TYPE_CHECKING
-import pkg_resources
+import logging
+from typing import TYPE_CHECKING, List
from twisted.web.http import Request
from twisted.web.resource import Resource
-from twisted.web.static import File
from synapse.api.errors import SynapseError
-from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME
-from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource
-from synapse.http.servlet import parse_string
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import (
+ DirectServeHtmlResource,
+ DirectServeJsonResource,
+ respond_with_html,
+)
+from synapse.http.servlet import parse_boolean, parse_string
from synapse.http.site import SynapseRequest
+from synapse.util.templates import build_jinja_env
if TYPE_CHECKING:
from synapse.server import HomeServer
+logger = logging.getLogger(__name__)
+
def pick_username_resource(hs: "HomeServer") -> Resource:
"""Factory method to generate the username picker resource.
- This resource gets mounted under /_synapse/client/pick_username. The top-level
- resource is just a File resource which serves up the static files in the resources
- "res" directory, but it has a couple of children:
-
- * "submit", which does the mechanics of registering the new user, and redirects the
- browser back to the client URL
+ This resource gets mounted under /_synapse/client/pick_username and has two
+ children:
- * "check": checks if a userid is free.
+ * "account_details": renders the form and handles the POSTed response
+ * "check": a JSON endpoint which checks if a userid is free.
"""
- # XXX should we make this path customisable so that admins can restyle it?
- base_path = pkg_resources.resource_filename("synapse", "res/username_picker")
-
- res = File(base_path)
- res.putChild(b"submit", SubmitResource(hs))
+ res = Resource()
+ res.putChild(b"account_details", AccountDetailsResource(hs))
res.putChild(b"check", AvailabilityCheckResource(hs))
return res
@@ -61,28 +61,71 @@ class AvailabilityCheckResource(DirectServeJsonResource):
async def _async_render_GET(self, request: Request):
localpart = parse_string(request, "username", required=True)
- session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
- if not session_id:
- raise SynapseError(code=400, msg="missing session_id")
+ session_id = get_username_mapping_session_cookie_from_request(request)
is_available = await self._sso_handler.check_username_availability(
- localpart, session_id.decode("ascii", errors="replace")
+ localpart, session_id
)
return 200, {"available": is_available}
-class SubmitResource(DirectServeHtmlResource):
+class AccountDetailsResource(DirectServeHtmlResource):
def __init__(self, hs: "HomeServer"):
super().__init__()
self._sso_handler = hs.get_sso_handler()
- async def _async_render_POST(self, request: SynapseRequest):
- localpart = parse_string(request, "username", required=True)
+ def template_search_dirs():
+ if hs.config.sso.sso_template_dir:
+ yield hs.config.sso.sso_template_dir
+ yield hs.config.sso.default_template_dir
+
+ self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+
+ async def _async_render_GET(self, request: Request) -> None:
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ session = self._sso_handler.get_mapping_session(session_id)
+ except SynapseError as e:
+ logger.warning("Error fetching session: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ idp_id = session.auth_provider_id
+ template_params = {
+ "idp": self._sso_handler.get_identity_providers()[idp_id],
+ "user_attributes": {
+ "display_name": session.display_name,
+ "emails": session.emails,
+ },
+ }
+
+ template = self._jinja_env.get_template("sso_auth_account_details.html")
+ html = template.render(template_params)
+ respond_with_html(request, 200, html)
- session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME)
- if not session_id:
- raise SynapseError(code=400, msg="missing session_id")
+ async def _async_render_POST(self, request: SynapseRequest):
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ except SynapseError as e:
+ logger.warning("Error fetching session cookie: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ try:
+ localpart = parse_string(request, "username", required=True)
+ use_display_name = parse_boolean(request, "use_display_name", default=False)
+
+ try:
+ emails_to_use = [
+ val.decode("utf-8") for val in request.args.get(b"use_email", [])
+ ] # type: List[str]
+ except ValueError:
+ raise SynapseError(400, "Query parameter use_email must be utf-8")
+ except SynapseError as e:
+ logger.warning("[session %s] bad param: %s", session_id, e)
+ self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code)
+ return
await self._sso_handler.handle_submit_username_request(
- request, localpart, session_id.decode("ascii", errors="replace")
+ request, session_id, localpart, use_display_name, emails_to_use
)
diff --git a/synapse/rest/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py
index 68da37ca6a..3e8235ee1e 100644
--- a/synapse/rest/saml2/__init__.py
+++ b/synapse/rest/synapse/client/saml2/__init__.py
@@ -12,12 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
from twisted.web.resource import Resource
-from synapse.rest.saml2.metadata_resource import SAML2MetadataResource
-from synapse.rest.saml2.response_resource import SAML2ResponseResource
+from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource
+from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource
logger = logging.getLogger(__name__)
@@ -27,3 +28,6 @@ class SAML2Resource(Resource):
Resource.__init__(self)
self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
self.putChild(b"authn_response", SAML2ResponseResource(hs))
+
+
+__all__ = ["SAML2Resource"]
diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py
index 1e8526e22e..1e8526e22e 100644
--- a/synapse/rest/saml2/metadata_resource.py
+++ b/synapse/rest/synapse/client/saml2/metadata_resource.py
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py
index f6668fb5e3..f6668fb5e3 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/synapse/client/saml2/response_resource.py
diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py
new file mode 100644
index 0000000000..dfefeb7796
--- /dev/null
+++ b/synapse/rest/synapse/client/sso_register.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SsoRegisterResource(DirectServeHtmlResource):
+ """A resource which completes SSO registration
+
+ This resource gets mounted at /_synapse/client/sso_register, and is shown
+ after we collect username and/or consent for a new SSO user. It (finally) registers
+ the user, and confirms redirect to the client
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+
+ async def _async_render_GET(self, request: Request) -> None:
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ except SynapseError as e:
+ logger.warning("Error fetching session cookie: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+ await self._sso_handler.register_sso_user(request, session_id)
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 241fe746d9..f591cc6c5c 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -34,6 +34,10 @@ class WellKnownBuilder:
self._config = hs.config
def get_well_known(self):
+ # if we don't have a public_baseurl, we can't help much here.
+ if self._config.public_baseurl is None:
+ return None
+
result = {"m.homeserver": {"base_url": self._config.public_baseurl}}
if self._config.default_identity_server:
|