diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 55ddebb4fe..f5c5d164f9 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -1,6 +1,8 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018-2019 New Vector Ltd
+# Copyright 2020, 2021 The Matrix.org Foundation C.I.C.
+
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -36,10 +38,13 @@ from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_medi
from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
+ ForwardExtremitiesRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
+ MakeRoomAdminRestServlet,
RoomMembersRestServlet,
RoomRestServlet,
+ RoomStateRestServlet,
ShutdownRoomRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
@@ -50,6 +55,7 @@ from synapse.rest.admin.users import (
PushersRestServlet,
ResetPasswordRestServlet,
SearchUsersRestServlet,
+ ShadowBanRestServlet,
UserAdminServlet,
UserMediaRestServlet,
UserMembershipRestServlet,
@@ -208,6 +214,7 @@ def register_servlets(hs, http_server):
"""
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
+ RoomStateRestServlet(hs).register(http_server)
RoomRestServlet(hs).register(http_server)
RoomMembersRestServlet(hs).register(http_server)
DeleteRoomRestServlet(hs).register(http_server)
@@ -228,6 +235,9 @@ def register_servlets(hs, http_server):
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
+ MakeRoomAdminRestServlet(hs).register(http_server)
+ ShadowBanRestServlet(hs).register(http_server)
+ ForwardExtremitiesRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(hs, http_server):
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index c82b4f87d6..8720b1401f 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -15,6 +15,9 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Tuple
+
+from twisted.web.http import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
@@ -23,6 +26,10 @@ from synapse.rest.admin._base import (
assert_requester_is_admin,
assert_user_is_admin,
)
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -39,11 +46,11 @@ class QuarantineMediaInRoom(RestServlet):
admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, room_id: str):
+ async def on_POST(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -64,11 +71,11 @@ class QuarantineMediaByUser(RestServlet):
PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, user_id: str):
+ async def on_POST(self, request: Request, user_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -91,11 +98,13 @@ class QuarantineMediaByID(RestServlet):
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_POST(self, request, server_name: str, media_id: str):
+ async def on_POST(
+ self, request: Request, server_name: str, media_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -109,17 +118,39 @@ class QuarantineMediaByID(RestServlet):
return 200, {}
+class ProtectMediaByID(RestServlet):
+ """Protect local media from being quarantined.
+ """
+
+ PATTERNS = admin_patterns("/media/protect/(?P<media_id>[^/]+)")
+
+ def __init__(self, hs: "HomeServer"):
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request: Request, media_id: str) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ logging.info("Protecting local media by ID: %s", media_id)
+
+ # Quarantine this media id
+ await self.store.mark_local_media_as_safe(media_id)
+
+ return 200, {}
+
+
class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
- async def on_GET(self, request, room_id):
+ async def on_GET(self, request: Request, room_id: str) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
is_admin = await self.auth.is_server_admin(requester.user)
if not is_admin:
@@ -133,11 +164,11 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet(RestServlet):
PATTERNS = admin_patterns("/purge_media_cache")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.media_repository = hs.get_media_repository()
self.auth = hs.get_auth()
- async def on_POST(self, request):
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -154,13 +185,15 @@ class DeleteMediaByID(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
- async def on_DELETE(self, request, server_name: str, media_id: str):
+ async def on_DELETE(
+ self, request: Request, server_name: str, media_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if self.server_name != server_name:
@@ -182,13 +215,13 @@ class DeleteMediaByDateSize(RestServlet):
PATTERNS = admin_patterns("/media/(?P<server_name>[^/]+)/delete")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.server_name = hs.hostname
self.media_repository = hs.get_media_repository()
- async def on_POST(self, request, server_name: str):
+ async def on_POST(self, request: Request, server_name: str) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
before_ts = parse_integer(request, "before_ts", required=True)
@@ -222,7 +255,7 @@ class DeleteMediaByDateSize(RestServlet):
return 200, {"deleted_media": deleted_media, "total": total}
-def register_servlets_for_media_repo(hs, http_server):
+def register_servlets_for_media_repo(hs: "HomeServer", http_server):
"""
Media repo specific APIs.
"""
@@ -230,6 +263,7 @@ def register_servlets_for_media_repo(hs, http_server):
QuarantineMediaInRoom(hs).register(http_server)
QuarantineMediaByID(hs).register(http_server)
QuarantineMediaByUser(hs).register(http_server)
+ ProtectMediaByID(hs).register(http_server)
ListMediaInRoom(hs).register(http_server)
DeleteMediaByID(hs).register(http_server)
DeleteMediaByDateSize(hs).register(http_server)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 25f89e4685..3e57e6a4d0 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,10 +14,10 @@
# limitations under the License.
import logging
from http import HTTPStatus
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
-from synapse.api.constants import EventTypes, JoinRules
-from synapse.api.errors import Codes, NotFoundError, SynapseError
+from synapse.api.constants import EventTypes, JoinRules, Membership
+from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -25,13 +25,18 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
)
from synapse.storage.databases.main.room import RoomSortOrder
-from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -45,12 +50,14 @@ class ShutdownRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -86,13 +93,15 @@ class DeleteRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
self.pagination_handler = hs.get_pagination_handler()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -146,12 +155,12 @@ class ListRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -236,19 +245,24 @@ class RoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room_with_stats(room_id)
if not ret:
raise NotFoundError("Room not found")
- return 200, ret
+ members = await self.store.get_users_in_room(room_id)
+ ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
+
+ return (200, ret)
class RoomMembersRestServlet(RestServlet):
@@ -258,12 +272,14 @@ class RoomMembersRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
@@ -276,18 +292,59 @@ class RoomMembersRestServlet(RestServlet):
return 200, ret
+class RoomStateRestServlet(RestServlet):
+ """
+ Get full state within a room.
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/state")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self._event_serializer = hs.get_event_client_serializer()
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ ret = await self.store.get_room(room_id)
+ if not ret:
+ raise NotFoundError("Room not found")
+
+ event_ids = await self.store.get_current_state_ids(room_id)
+ events = await self.store.get_events(event_ids.values())
+ now = self.clock.time_msec()
+ room_state = await self._event_serializer.serialize_events(
+ events.values(),
+ now,
+ # We don't bother bundling aggregations in when asked for state
+ # events, as clients won't use them.
+ bundle_aggregations=False,
+ )
+ ret = {"state": room_state}
+
+ return 200, ret
+
+
class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
- async def on_POST(self, request, room_identifier):
+ async def on_POST(
+ self, request: SynapseRequest, room_identifier: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -314,7 +371,6 @@ class JoinRoomAliasServlet(RestServlet):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
- room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
@@ -351,3 +407,201 @@ class JoinRoomAliasServlet(RestServlet):
)
return 200, {"room_id": room_id}
+
+
+class MakeRoomAdminRestServlet(RestServlet):
+ """Allows a server admin to get power in a room if a local user has power in
+ a room. Will also invite the user if they're not in the room and it's a
+ private room. Can specify another user (rather than the admin user) to be
+ granted power, e.g.:
+
+ POST/_synapse/admin/v1/rooms/<room_id_or_alias>/make_room_admin
+ {
+ "user_id": "@foo:example.com"
+ }
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.state_handler = hs.get_state_handler()
+ self.is_mine_id = hs.is_mine_id
+
+ async def on_POST(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+ content = parse_json_object_from_request(request, allow_empty_body=True)
+
+ # Resolve to a room ID, if necessary.
+ if RoomID.is_valid(room_identifier):
+ room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+ room_id = room_id.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+
+ # Which user to grant room admin rights to.
+ user_to_add = content.get("user_id", requester.user.to_string())
+
+ # Figure out which local users currently have power in the room, if any.
+ room_state = await self.state_handler.get_current_state(room_id)
+ if not room_state:
+ raise SynapseError(400, "Server not in room")
+
+ create_event = room_state[(EventTypes.Create, "")]
+ power_levels = room_state.get((EventTypes.PowerLevels, ""))
+
+ if power_levels is not None:
+ # We pick the local user with the highest power.
+ user_power = power_levels.content.get("users", {})
+ admin_users = [
+ user_id for user_id in user_power if self.is_mine_id(user_id)
+ ]
+ admin_users.sort(key=lambda user: user_power[user])
+
+ if not admin_users:
+ raise SynapseError(400, "No local admin user in room")
+
+ admin_user_id = None
+
+ for admin_user in reversed(admin_users):
+ if room_state.get((EventTypes.Member, admin_user)):
+ admin_user_id = admin_user
+ break
+
+ if not admin_user_id:
+ raise SynapseError(
+ 400, "No local admin user in room",
+ )
+
+ pl_content = power_levels.content
+ else:
+ # If there is no power level events then the creator has rights.
+ pl_content = {}
+ admin_user_id = create_event.sender
+ if not self.is_mine_id(admin_user_id):
+ raise SynapseError(
+ 400, "No local admin user in room",
+ )
+
+ # Grant the user power equal to the room admin by attempting to send an
+ # updated power level event.
+ new_pl_content = dict(pl_content)
+ new_pl_content["users"] = dict(pl_content.get("users", {}))
+ new_pl_content["users"][user_to_add] = new_pl_content["users"][admin_user_id]
+
+ fake_requester = create_requester(
+ admin_user_id, authenticated_entity=requester.authenticated_entity,
+ )
+
+ try:
+ await self.event_creation_handler.create_and_send_nonmember_event(
+ fake_requester,
+ event_dict={
+ "content": new_pl_content,
+ "sender": admin_user_id,
+ "type": EventTypes.PowerLevels,
+ "state_key": "",
+ "room_id": room_id,
+ },
+ )
+ except AuthError:
+ # The admin user we found turned out not to have enough power.
+ raise SynapseError(
+ 400, "No local admin user in room with power to update power levels."
+ )
+
+ # Now we check if the user we're granting admin rights to is already in
+ # the room. If not and it's not a public room we invite them.
+ member_event = room_state.get((EventTypes.Member, user_to_add))
+ is_joined = False
+ if member_event:
+ is_joined = member_event.content["membership"] in (
+ Membership.JOIN,
+ Membership.INVITE,
+ )
+
+ if is_joined:
+ return 200, {}
+
+ join_rules = room_state.get((EventTypes.JoinRules, ""))
+ is_public = False
+ if join_rules:
+ is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
+
+ if is_public:
+ return 200, {}
+
+ await self.room_member_handler.update_membership(
+ fake_requester,
+ target=UserID.from_string(user_to_add),
+ room_id=room_id,
+ action=Membership.INVITE,
+ )
+
+ return 200, {}
+
+
+class ForwardExtremitiesRestServlet(RestServlet):
+ """Allows a server admin to get or clear forward extremities.
+
+ Clearing does not require restarting the server.
+
+ Clear forward extremities:
+ DELETE /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+
+ Get forward_extremities:
+ GET /_synapse/admin/v1/rooms/<room_id_or_alias>/forward_extremities
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.room_member_handler = hs.get_room_member_handler()
+ self.store = hs.get_datastore()
+
+ async def resolve_room_id(self, room_identifier: str) -> str:
+ """Resolve to a room ID, if necessary."""
+ if RoomID.is_valid(room_identifier):
+ resolved_room_id = room_identifier
+ elif RoomAlias.is_valid(room_identifier):
+ room_alias = RoomAlias.from_string(room_identifier)
+ room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
+ resolved_room_id = room_id.to_string()
+ else:
+ raise SynapseError(
+ 400, "%s was not legal room ID or room alias" % (room_identifier,)
+ )
+ if not resolved_room_id:
+ raise SynapseError(
+ 400, "Unknown room ID or room alias %s" % room_identifier
+ )
+ return resolved_room_id
+
+ async def on_DELETE(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ room_id = await self.resolve_room_id(room_identifier)
+
+ deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
+ return 200, {"deleted": deleted_count}
+
+ async def on_GET(self, request, room_identifier):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ room_id = await self.resolve_room_id(room_identifier)
+
+ extremities = await self.store.get_forward_extremities_for_room(room_id)
+ return 200, {"count": len(extremities), "results": extremities}
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b0ff5e1ead..68c3c64a0d 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -42,17 +42,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-_GET_PUSHERS_ALLOWED_KEYS = {
- "app_display_name",
- "app_id",
- "data",
- "device_display_name",
- "kind",
- "lang",
- "profile_tag",
- "pushkey",
-}
-
class UsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
@@ -94,17 +83,32 @@ class UsersRestServletV2(RestServlet):
The parameter `deactivated` can be used to include deactivated users.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100)
+
+ if start < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter from must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
+ if limit < 0:
+ raise SynapseError(
+ 400,
+ "Query parameter limit must be a string representing a positive integer.",
+ errcode=Codes.INVALID_PARAM,
+ )
+
user_id = parse_string(request, "user_id", default=None)
name = parse_string(request, "name", default=None)
guests = parse_boolean(request, "guests", default=True)
@@ -114,7 +118,7 @@ class UsersRestServletV2(RestServlet):
start, limit, user_id, name, guests, deactivated
)
ret = {"users": users, "total": total}
- if len(users) >= limit:
+ if (start + limit) < total:
ret["next_token"] = str(start + len(users))
return 200, ret
@@ -255,7 +259,7 @@ class UserRestServletV2(RestServlet):
if deactivate and not user["deactivated"]:
await self.deactivate_account_handler.deactivate_account(
- target_user.to_string(), False
+ target_user.to_string(), False, requester, by_admin=True
)
elif not deactivate and user["deactivated"]:
if "password" not in body:
@@ -320,9 +324,9 @@ class UserRestServletV2(RestServlet):
data={},
)
- if "avatar_url" in body and type(body["avatar_url"]) == str:
+ if "avatar_url" in body and isinstance(body["avatar_url"], str):
await self.profile_handler.set_avatar_url(
- user_id, requester, body["avatar_url"], True
+ target_user, requester, body["avatar_url"], True
)
ret = await self.admin_handler.get_user(target_user)
@@ -420,6 +424,9 @@ class UserRegisterServlet(RestServlet):
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
raise SynapseError(400, "Invalid user type")
+ if "mac" not in body:
+ raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
+
got_mac = body["mac"]
want_mac_builder = hmac.new(
@@ -494,12 +501,22 @@ class WhoisRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
self.auth = hs.get_auth()
+ self.is_mine = hs.is_mine
+ self.store = hs.get_datastore()
+
+ async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+
+ if not self.is_mine(UserID.from_string(target_user_id)):
+ raise SynapseError(400, "Can only deactivate local users")
+
+ if not await self.store.get_user_by_id(target_user_id):
+ raise NotFoundError("User not found")
- async def on_POST(self, request, target_user_id):
- await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request, allow_empty_body=True)
erase = body.get("erase", False)
if not isinstance(erase, bool):
@@ -509,10 +526,8 @@ class DeactivateAccountRestServlet(RestServlet):
Codes.BAD_JSON,
)
- UserID.from_string(target_user_id)
-
result = await self._deactivate_account_handler.deactivate_account(
- target_user_id, erase
+ target_user_id, erase, requester, by_admin=True
)
if result:
id_server_unbind_result = "success"
@@ -722,13 +737,6 @@ class UserMembershipRestServlet(RestServlet):
async def on_GET(self, request, user_id):
await assert_requester_is_admin(self.auth, request)
- if not self.is_mine(UserID.from_string(user_id)):
- raise SynapseError(400, "Can only lookup local users")
-
- user = await self.store.get_user_by_id(user_id)
- if user is None:
- raise NotFoundError("Unknown user")
-
room_ids = await self.store.get_rooms_for_user(user_id)
ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
return 200, ret
@@ -767,10 +775,7 @@ class PushersRestServlet(RestServlet):
pushers = await self.store.get_pushers_by_user_id(user_id)
- filtered_pushers = [
- {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
- for p in pushers
- ]
+ filtered_pushers = [p.as_dict() for p in pushers]
return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
@@ -885,3 +890,39 @@ class UserTokenRestServlet(RestServlet):
)
return 200, {"access_token": token}
+
+
+class ShadowBanRestServlet(RestServlet):
+ """An admin API for shadow-banning a user.
+
+ A shadow-banned users receives successful responses to their client-server
+ API requests, but the events are not propagated into rooms.
+
+ Shadow-banning a user should be used as a tool of last resort and may lead
+ to confusing or broken behaviour for the client.
+
+ Example:
+
+ POST /_synapse/admin/v1/users/@test:example.com/shadow_ban
+ {}
+
+ 200 OK
+ {}
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/shadow_ban")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+
+ async def on_POST(self, request, user_id):
+ await assert_requester_is_admin(self.auth, request)
+
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Only local users can be shadow-banned")
+
+ await self.store.set_shadow_banned(UserID.from_string(user_id), True)
+
+ return 200, {}
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d7ae148214..0fb9419e58 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,12 +14,13 @@
# limitations under the License.
import logging
-from typing import Awaitable, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
-from synapse.http.server import finish_request
+from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
@@ -30,6 +31,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -42,7 +46,7 @@ class LoginRestServlet(RestServlet):
JWT_TYPE_DEPRECATED = "m.login.jwt"
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -57,11 +61,14 @@ class LoginRestServlet(RestServlet):
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+ self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self.auth = hs.get_auth()
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
+ self._sso_handler = hs.get_sso_handler()
+
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
@@ -86,8 +93,17 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
- # While its valid for us to advertise this login type generally,
+ sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
+
+ if self._msc2858_enabled:
+ sso_flow["org.matrix.msc2858.identity_providers"] = [
+ _get_auth_flow_dict_for_idp(idp)
+ for idp in self._sso_handler.get_identity_providers().values()
+ ]
+
+ flows.append(sso_flow)
+
+ # While it's valid for us to advertise this login type generally,
# synapse currently only gives out these tokens as part of the
# SSO login flow.
# Generally we don't want to advertise login flows that clients
@@ -105,22 +121,27 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
async def on_POST(self, request: SynapseRequest):
- self._address_ratelimiter.ratelimit(request.getClientIP())
-
login_submission = parse_json_object_from_request(request)
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
appservice = self.auth.get_appservice_by_req(request)
+
+ if appservice.is_rate_limited():
+ self._address_ratelimiter.ratelimit(request.getClientIP())
+
result = await self._do_appservice_login(login_submission, appservice)
elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_token_login(login_submission)
else:
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_other_login(login_submission)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -159,7 +180,9 @@ class LoginRestServlet(RestServlet):
if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
- return await self._complete_login(qualified_user_id, login_submission)
+ return await self._complete_login(
+ qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+ )
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
@@ -194,6 +217,7 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
+ ratelimit: bool = True,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -208,6 +232,7 @@ class LoginRestServlet(RestServlet):
callback: Callback function to run after login.
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
+ ratelimit: Whether to ratelimit the login request.
Returns:
result: Dictionary of account information after successful login.
@@ -216,7 +241,8 @@ class LoginRestServlet(RestServlet):
# Before we actually log them in we check if they've already logged in
# too often. This happens here rather than before as we don't
# necessarily know the user before now.
- self._account_ratelimiter.ratelimit(user_id.lower())
+ if ratelimit:
+ self._account_ratelimiter.ratelimit(user_id.lower())
if create_non_existent_users:
canonical_uid = await self.auth_handler.check_user_exists(user_id)
@@ -298,48 +324,63 @@ class LoginRestServlet(RestServlet):
return result
-class BaseSSORedirectServlet(RestServlet):
- """Common base class for /login/sso/redirect impls"""
-
- PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+ """Return an entry for the login flow dict
+
+ Returns an entry suitable for inclusion in "identity_providers" in the
+ response to GET /_matrix/client/r0/login
+ """
+ e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
+ if idp.idp_icon:
+ e["icon"] = idp.idp_icon
+ if idp.idp_brand:
+ e["brand"] = idp.idp_brand
+ return e
+
+
+class SsoRedirectServlet(RestServlet):
+ PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
+
+ def __init__(self, hs: "HomeServer"):
+ # make sure that the relevant handlers are instantiated, so that they
+ # register themselves with the main SSOHandler.
+ if hs.config.cas_enabled:
+ hs.get_cas_handler()
+ if hs.config.saml2_enabled:
+ hs.get_saml_handler()
+ if hs.config.oidc_enabled:
+ hs.get_oidc_handler()
+ self._sso_handler = hs.get_sso_handler()
+ self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+
+ def register(self, http_server: HttpServer) -> None:
+ super().register(http_server)
+ if self._msc2858_enabled:
+ # expose additional endpoint for MSC2858 support
+ http_server.register_paths(
+ "GET",
+ client_patterns(
+ "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
+ releases=(),
+ unstable=True,
+ ),
+ self.on_GET,
+ self.__class__.__name__,
+ )
- async def on_GET(self, request: SynapseRequest):
- args = request.args
- if b"redirectUrl" not in args:
- return 400, "Redirect URL not specified for SSO auth"
- client_redirect_url = args[b"redirectUrl"][0]
- sso_url = await self.get_sso_url(request, client_redirect_url)
+ async def on_GET(
+ self, request: SynapseRequest, idp_id: Optional[str] = None
+ ) -> None:
+ client_redirect_url = parse_string(
+ request, "redirectUrl", required=True, encoding=None
+ )
+ sso_url = await self._sso_handler.handle_redirect_request(
+ request, client_redirect_url, idp_id,
+ )
+ logger.info("Redirecting to %s", sso_url)
request.redirect(sso_url)
finish_request(request)
- async def get_sso_url(
- self, request: SynapseRequest, client_redirect_url: bytes
- ) -> bytes:
- """Get the URL to redirect to, to perform SSO auth
-
- Args:
- request: The client request to redirect.
- client_redirect_url: the URL that we should redirect the
- client to when everything is done
-
- Returns:
- URL to redirect to
- """
- # to be implemented by subclasses
- raise NotImplementedError()
-
-
-class CasRedirectServlet(BaseSSORedirectServlet):
- def __init__(self, hs):
- self._cas_handler = hs.get_cas_handler()
-
- async def get_sso_url(
- self, request: SynapseRequest, client_redirect_url: bytes
- ) -> bytes:
- return self._cas_handler.get_redirect_url(
- {"redirectUrl": client_redirect_url}
- ).encode("ascii")
-
class CasTicketServlet(RestServlet):
PATTERNS = client_patterns("/login/cas/ticket", v1=True)
@@ -366,40 +407,8 @@ class CasTicketServlet(RestServlet):
)
-class SAMLRedirectServlet(BaseSSORedirectServlet):
- PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
- def __init__(self, hs):
- self._saml_handler = hs.get_saml_handler()
-
- async def get_sso_url(
- self, request: SynapseRequest, client_redirect_url: bytes
- ) -> bytes:
- return self._saml_handler.handle_redirect_request(client_redirect_url)
-
-
-class OIDCRedirectServlet(BaseSSORedirectServlet):
- """Implementation for /login/sso/redirect for the OIDC login flow."""
-
- PATTERNS = client_patterns("/login/sso/redirect", v1=True)
-
- def __init__(self, hs):
- self._oidc_handler = hs.get_oidc_handler()
-
- async def get_sso_url(
- self, request: SynapseRequest, client_redirect_url: bytes
- ) -> bytes:
- return await self._oidc_handler.handle_redirect_request(
- request, client_redirect_url
- )
-
-
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
+ SsoRedirectServlet(hs).register(http_server)
if hs.config.cas_enabled:
- CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
- elif hs.config.saml2_enabled:
- SAMLRedirectServlet(hs).register(http_server)
- elif hs.config.oidc_enabled:
- OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 8fe83f321a..89823fcc39 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
logger = logging.getLogger(__name__)
-ALLOWED_KEYS = {
- "app_display_name",
- "app_id",
- "data",
- "device_display_name",
- "kind",
- "lang",
- "profile_tag",
- "pushkey",
-}
-
class PushersRestServlet(RestServlet):
PATTERNS = client_patterns("/pushers$", v1=True)
@@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet):
pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
- filtered_pushers = [
- {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
- ]
+ filtered_pushers = [p.as_dict() for p in pushers]
return 200, {"pushers": filtered_pushers}
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 93c06afe27..f95627ee61 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -46,7 +46,7 @@ from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from synapse.util import json_decoder
-from synapse.util.stringutils import random_string
+from synapse.util.stringutils import parse_and_validate_server_name, random_string
if TYPE_CHECKING:
import synapse.server
@@ -347,8 +347,6 @@ class PublicRoomListRestServlet(TransactionRestServlet):
# provided.
if server:
raise e
- else:
- pass
limit = parse_integer(request, "limit", 0)
since_token = parse_string(request, "since", None)
@@ -359,6 +357,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server_name:
+ # Ensure the server is valid.
+ try:
+ parse_and_validate_server_name(server)
+ except ValueError:
+ raise SynapseError(
+ 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+ )
+
try:
data = await handler.get_remote_public_room_list(
server, limit=limit, since_token=since_token
@@ -402,6 +408,14 @@ class PublicRoomListRestServlet(TransactionRestServlet):
handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server_name:
+ # Ensure the server is valid.
+ try:
+ parse_and_validate_server_name(server)
+ except ValueError:
+ raise SynapseError(
+ 400, "Invalid server name: %s" % (server,), Codes.INVALID_PARAM,
+ )
+
try:
data = await handler.get_remote_public_room_list(
server,
@@ -963,25 +977,28 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False):
)
-def register_servlets(hs, http_server):
+def register_servlets(hs, http_server, is_worker=False):
RoomStateEventRestServlet(hs).register(http_server)
- RoomCreateRestServlet(hs).register(http_server)
RoomMemberListRestServlet(hs).register(http_server)
JoinedRoomMemberListRestServlet(hs).register(http_server)
RoomMessageListRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
- RoomForgetRestServlet(hs).register(http_server)
RoomMembershipRestServlet(hs).register(http_server)
RoomSendEventRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server)
RoomStateRestServlet(hs).register(http_server)
RoomRedactEventRestServlet(hs).register(http_server)
RoomTypingRestServlet(hs).register(http_server)
- SearchRestServlet(hs).register(http_server)
- JoinedRoomsRestServlet(hs).register(http_server)
- RoomEventServlet(hs).register(http_server)
RoomEventContextServlet(hs).register(http_server)
- RoomAliasListServlet(hs).register(http_server)
+
+ # Some servlets only get registered for the main process.
+ if not is_worker:
+ RoomCreateRestServlet(hs).register(http_server)
+ RoomForgetRestServlet(hs).register(http_server)
+ SearchRestServlet(hs).register(http_server)
+ JoinedRoomsRestServlet(hs).register(http_server)
+ RoomEventServlet(hs).register(http_server)
+ RoomAliasListServlet(hs).register(http_server)
def register_deprecated_servlets(hs, http_server):
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index e0feebea94..b67c1702ca 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -20,9 +20,6 @@ from http import HTTPStatus
from typing import TYPE_CHECKING
from urllib.parse import urlparse
-if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
-
from synapse.api.constants import LoginType
from synapse.api.errors import (
Codes,
@@ -31,6 +28,7 @@ from synapse.api.errors import (
ThreepidValidationError,
)
from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
@@ -46,13 +44,17 @@ from synapse.util.threepids import canonicalise_email, check_3pid_allowed
from ._base import client_patterns, interactive_auth_handler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+
logger = logging.getLogger(__name__)
class EmailPasswordRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/password/email/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.datastore = hs.get_datastore()
@@ -101,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
# The email will be sent to the stored address.
# This avoids a potential account hijack by requesting a password reset to
# an email address which is controlled by the attacker but which, after
@@ -189,11 +193,7 @@ class PasswordRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
try:
params, session_id = await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "modify your account password",
+ requester, request, body, "modify your account password",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth, but
@@ -204,7 +204,9 @@ class PasswordRestServlet(RestServlet):
if new_password:
password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
- e.session_id, "password_hash", password_hash
+ e.session_id,
+ UIAuthSessionDataConstants.PASSWORD_HASH,
+ password_hash,
)
raise
user_id = requester.user.to_string()
@@ -215,7 +217,6 @@ class PasswordRestServlet(RestServlet):
[[LoginType.EMAIL_IDENTITY]],
request,
body,
- self.hs.get_ip_from_request(request),
"modify your account password",
)
except InteractiveAuthIncompleteError as e:
@@ -227,7 +228,9 @@ class PasswordRestServlet(RestServlet):
if new_password:
password_hash = await self.auth_handler.hash(new_password)
await self.auth_handler.set_session_data(
- e.session_id, "password_hash", password_hash
+ e.session_id,
+ UIAuthSessionDataConstants.PASSWORD_HASH,
+ password_hash,
)
raise
@@ -254,14 +257,18 @@ class PasswordRestServlet(RestServlet):
logger.error("Auth succeeded but no known type! %r", result.keys())
raise SynapseError(500, "", Codes.UNKNOWN)
- # If we have a password in this request, prefer it. Otherwise, there
- # must be a password hash from an earlier request.
+ # If we have a password in this request, prefer it. Otherwise, use the
+ # password hash from an earlier request.
if new_password:
password_hash = await self.auth_handler.hash(new_password)
- else:
+ elif session_id is not None:
password_hash = await self.auth_handler.get_session_data(
- session_id, "password_hash", None
+ session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
)
+ else:
+ # UI validation was skipped, but the request did not include a new
+ # password.
+ password_hash = None
if not password_hash:
raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM)
@@ -300,19 +307,18 @@ class DeactivateAccountRestServlet(RestServlet):
# allow ASes to deactivate their own users
if requester.app_service:
await self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(), erase
+ requester.user.to_string(), erase, requester
)
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "deactivate your account",
+ requester, request, body, "deactivate your account",
)
result = await self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(), erase, id_server=body.get("id_server")
+ requester.user.to_string(),
+ erase,
+ requester,
+ id_server=body.get("id_server"),
)
if result:
id_server_unbind_result = "success"
@@ -375,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -426,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
super().__init__()
self.store = self.hs.get_datastore()
@@ -454,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
if next_link:
# Raise if the provided next_link value isn't valid
assert_valid_next_link(self.hs, next_link)
@@ -691,11 +703,7 @@ class ThreepidAddRestServlet(RestServlet):
assert_valid_client_secret(client_secret)
await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "add a third-party identifier to your account",
+ requester, request, body, "add a third-party identifier to your account",
)
validation_session = await self.identity_handler.validate_threepid_session(
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 87a5b1b86b..3f28c0bc3e 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
- self._is_worker = hs.config.worker_app is not None
+ self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, account_data_type):
- if self._is_worker:
- raise Exception("Cannot handle PUT /account_data on worker")
-
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
body = parse_json_object_from_request(request)
- max_id = await self.store.add_account_data_for_user(
- user_id, account_data_type, body
- )
-
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+ await self.handler.add_account_data_for_user(user_id, account_data_type, body)
return 200, {}
@@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
- self._is_worker = hs.config.worker_app is not None
+ self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, account_data_type):
- if self._is_worker:
- raise Exception("Cannot handle PUT /account_data on worker")
-
requester = await self.auth.get_user_by_req(request)
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add account data for other users.")
@@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet):
" Use /rooms/!roomId:server.name/read_markers",
)
- max_id = await self.store.add_account_data_to_room(
+ await self.handler.add_account_data_to_room(
user_id, room_id, account_data_type, body
)
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
-
return 200, {}
async def on_GET(self, request, user_id, room_id, account_data_type):
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index fab077747f..75ece1c911 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError
@@ -23,6 +24,9 @@ from synapse.http.servlet import RestServlet, parse_string
from ._base import client_patterns
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -35,28 +39,12 @@ class AuthRestServlet(RestServlet):
PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
-
- # SSO configuration.
- self._cas_enabled = hs.config.cas_enabled
- if self._cas_enabled:
- self._cas_handler = hs.get_cas_handler()
- self._cas_server_url = hs.config.cas_server_url
- self._cas_service_url = hs.config.cas_service_url
- self._saml_enabled = hs.config.saml2_enabled
- if self._saml_enabled:
- self._saml_handler = hs.get_saml_handler()
- self._oidc_enabled = hs.config.oidc_enabled
- if self._oidc_enabled:
- self._oidc_handler = hs.get_oidc_handler()
- self._cas_server_url = hs.config.cas_server_url
- self._cas_service_url = hs.config.cas_service_url
-
self.recaptcha_template = hs.config.recaptcha_template
self.terms_template = hs.config.terms_template
self.success_template = hs.config.fallback_success_template
@@ -85,32 +73,7 @@ class AuthRestServlet(RestServlet):
elif stagetype == LoginType.SSO:
# Display a confirmation page which prompts the user to
# re-authenticate with their SSO provider.
- if self._cas_enabled:
- # Generate a request to CAS that redirects back to an endpoint
- # to verify the successful authentication.
- sso_redirect_url = self._cas_handler.get_redirect_url(
- {"session": session},
- )
-
- elif self._saml_enabled:
- # Some SAML identity providers (e.g. Google) require a
- # RelayState parameter on requests. It is not necessary here, so
- # pass in a dummy redirect URL (which will never get used).
- client_redirect_url = b"unused"
- sso_redirect_url = self._saml_handler.handle_redirect_request(
- client_redirect_url, session
- )
-
- elif self._oidc_enabled:
- client_redirect_url = b""
- sso_redirect_url = await self._oidc_handler.handle_redirect_request(
- request, client_redirect_url, session
- )
-
- else:
- raise SynapseError(400, "Homeserver not configured for SSO.")
-
- html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+ html = await self.auth_handler.start_sso_ui_auth(request, session)
else:
raise SynapseError(404, "Unknown auth stage type")
@@ -134,7 +97,7 @@ class AuthRestServlet(RestServlet):
authdict = {"response": response, "session": session}
success = await self.auth_handler.add_oob_auth(
- LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
+ LoginType.RECAPTCHA, authdict, request.getClientIP()
)
if success:
@@ -150,7 +113,7 @@ class AuthRestServlet(RestServlet):
authdict = {"session": session}
success = await self.auth_handler.add_oob_auth(
- LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
+ LoginType.TERMS, authdict, request.getClientIP()
)
if success:
diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index af117cb27c..314e01dfe4 100644
--- a/synapse/rest/client/v2_alpha/devices.py
+++ b/synapse/rest/client/v2_alpha/devices.py
@@ -83,11 +83,7 @@ class DeleteDevicesRestServlet(RestServlet):
assert_params_in_dict(body, ["devices"])
await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "remove device(s) from your account",
+ requester, request, body, "remove device(s) from your account",
)
await self.device_handler.delete_devices(
@@ -133,11 +129,7 @@ class DeviceRestServlet(RestServlet):
raise
await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "remove a device from your account",
+ requester, request, body, "remove a device from your account",
)
await self.device_handler.delete_device(requester.user.to_string(), device_id)
diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index 75215a3779..28b55f27ad 100644
--- a/synapse/rest/client/v2_alpha/groups.py
+++ b/synapse/rest/client/v2_alpha/groups.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from functools import wraps
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
@@ -25,6 +26,22 @@ from ._base import client_patterns
logger = logging.getLogger(__name__)
+def _validate_group_id(f):
+ """Wrapper to validate the form of the group ID.
+
+ Can be applied to any on_FOO methods that accepts a group ID as a URL parameter.
+ """
+
+ @wraps(f)
+ def wrapper(self, request, group_id, *args, **kwargs):
+ if not GroupID.is_valid(group_id):
+ raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
+
+ return f(self, request, group_id, *args, **kwargs)
+
+ return wrapper
+
+
class GroupServlet(RestServlet):
"""Get the group profile
"""
@@ -37,6 +54,7 @@ class GroupServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -47,6 +65,7 @@ class GroupServlet(RestServlet):
return 200, group_description
+ @_validate_group_id
async def on_POST(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -71,6 +90,7 @@ class GroupSummaryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -102,6 +122,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, category_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -117,6 +138,7 @@ class GroupSummaryRoomsCatServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, category_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -142,6 +164,7 @@ class GroupCategoryServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -152,6 +175,7 @@ class GroupCategoryServlet(RestServlet):
return 200, category
+ @_validate_group_id
async def on_PUT(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -163,6 +187,7 @@ class GroupCategoryServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, category_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -186,6 +211,7 @@ class GroupCategoriesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -209,6 +235,7 @@ class GroupRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -219,6 +246,7 @@ class GroupRoleServlet(RestServlet):
return 200, category
+ @_validate_group_id
async def on_PUT(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -230,6 +258,7 @@ class GroupRoleServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, role_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -253,6 +282,7 @@ class GroupRolesServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -284,6 +314,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, role_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -299,6 +330,7 @@ class GroupSummaryUsersRoleServlet(RestServlet):
return 200, resp
+ @_validate_group_id
async def on_DELETE(self, request, group_id, role_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -322,13 +354,11 @@ class GroupRoomServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
- if not GroupID.is_valid(group_id):
- raise SynapseError(400, "%s was not legal group ID" % (group_id,))
-
result = await self.groups_handler.get_rooms_in_group(
group_id, requester_user_id
)
@@ -348,6 +378,7 @@ class GroupUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
requester_user_id = requester.user.to_string()
@@ -371,6 +402,7 @@ class GroupInvitedUsersServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_GET(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -393,6 +425,7 @@ class GroupSettingJoinPolicyServlet(RestServlet):
self.auth = hs.get_auth()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -449,6 +482,7 @@ class GroupAdminRoomsServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -460,6 +494,7 @@ class GroupAdminRoomsServlet(RestServlet):
return 200, result
+ @_validate_group_id
async def on_DELETE(self, request, group_id, room_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -486,6 +521,7 @@ class GroupAdminRoomsConfigServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, room_id, config_key):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -514,6 +550,7 @@ class GroupAdminUsersInviteServlet(RestServlet):
self.store = hs.get_datastore()
self.is_mine_id = hs.is_mine_id
+ @_validate_group_id
async def on_PUT(self, request, group_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -541,6 +578,7 @@ class GroupAdminUsersKickServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -590,6 +628,7 @@ class GroupSelfLeaveServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -614,6 +653,7 @@ class GroupSelfJoinServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -638,6 +678,7 @@ class GroupSelfAcceptInviteServlet(RestServlet):
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
@@ -662,6 +703,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastore()
+ @_validate_group_id
async def on_PUT(self, request, group_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index b91996c738..a6134ead8a 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -271,11 +271,7 @@ class SigningKeyUploadServlet(RestServlet):
body = parse_json_object_from_request(request)
await self.auth_handler.validate_user_via_ui_auth(
- requester,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "add a device signing key to your account",
+ requester, request, body, "add a device signing key to your account",
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 5374d2c1b6..f0675abd32 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -38,6 +38,7 @@ from synapse.config.ratelimiting import FederationRateLimitConfig
from synapse.config.registration import RegistrationConfig
from synapse.config.server import is_threepid_reserved
from synapse.handlers.auth import AuthHandler
+from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
@@ -125,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(request, "email", email)
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"email", email
)
@@ -204,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
Codes.THREEPID_DENIED,
)
+ self.identity_handler.ratelimit_request_token_requests(
+ request, "msisdn", msisdn
+ )
+
existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid(
"msisdn", msisdn
)
@@ -353,7 +360,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
403, "Registration has been disabled", errcode=Codes.FORBIDDEN
)
- ip = self.hs.get_ip_from_request(request)
+ ip = request.getClientIP()
with self.ratelimiter.ratelimit(ip) as wait_deferred:
await wait_deferred
@@ -451,7 +458,7 @@ class RegisterRestServlet(RestServlet):
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
- raise SynapseError(403, "Registration has been disabled")
+ raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
# For regular registration, convert the provided username to lowercase
# before attempting to register it. This should mean that people who try
@@ -494,11 +501,11 @@ class RegisterRestServlet(RestServlet):
# user here. We carry on and go through the auth checks though,
# for paranoia.
registered_user_id = await self.auth_handler.get_session_data(
- session_id, "registered_user_id", None
+ session_id, UIAuthSessionDataConstants.REGISTERED_USER_ID, None
)
# Extract the previously-hashed password from the session.
password_hash = await self.auth_handler.get_session_data(
- session_id, "password_hash", None
+ session_id, UIAuthSessionDataConstants.PASSWORD_HASH, None
)
# Ensure that the username is valid.
@@ -513,11 +520,7 @@ class RegisterRestServlet(RestServlet):
# not this will raise a user-interactive auth error.
try:
auth_result, params, session_id = await self.auth_handler.check_ui_auth(
- self._registration_flows,
- request,
- body,
- self.hs.get_ip_from_request(request),
- "register a new account",
+ self._registration_flows, request, body, "register a new account",
)
except InteractiveAuthIncompleteError as e:
# The user needs to provide more steps to complete auth.
@@ -532,7 +535,9 @@ class RegisterRestServlet(RestServlet):
if not password_hash and password:
password_hash = await self.auth_handler.hash(password)
await self.auth_handler.set_session_data(
- e.session_id, "password_hash", password_hash
+ e.session_id,
+ UIAuthSessionDataConstants.PASSWORD_HASH,
+ password_hash,
)
raise
@@ -635,7 +640,9 @@ class RegisterRestServlet(RestServlet):
# Remember that the user account has been registered (and the user
# ID it was registered with, since it might not have been specified).
await self.auth_handler.set_session_data(
- session_id, "registered_user_id", registered_user_id
+ session_id,
+ UIAuthSessionDataConstants.REGISTERED_USER_ID,
+ registered_user_id,
)
registered = True
@@ -657,9 +664,13 @@ class RegisterRestServlet(RestServlet):
user_id = await self.registration_handler.appservice_register(
username, as_token
)
- return await self._create_registration_details(user_id, body)
+ return await self._create_registration_details(
+ user_id, body, is_appservice_ghost=True,
+ )
- async def _create_registration_details(self, user_id, params):
+ async def _create_registration_details(
+ self, user_id, params, is_appservice_ghost=False
+ ):
"""Complete registration of newly-registered user
Allocates device_id if one was not given; also creates access_token.
@@ -676,7 +687,11 @@ class RegisterRestServlet(RestServlet):
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name, is_guest=False
+ user_id,
+ device_id,
+ initial_display_name,
+ is_guest=False,
+ is_appservice_ghost=is_appservice_ghost,
)
result.update({"access_token": access_token, "device_id": device_id})
diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index bc4f43639a..a3dee14ed4 100644
--- a/synapse/rest/client/v2_alpha/sendtodevice.py
+++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -17,7 +17,7 @@ import logging
from typing import Tuple
from synapse.http import servlet
-from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.logging.opentracing import set_tag, trace
from synapse.rest.client.transactions import HttpTransactionCache
@@ -54,6 +54,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
+ assert_params_in_dict(content, ("messages",))
sender_user_id = requester.user.to_string()
diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index bf3a79db44..a97cd66c52 100644
--- a/synapse/rest/client/v2_alpha/tags.py
+++ b/synapse/rest/client/v2_alpha/tags.py
@@ -58,8 +58,7 @@ class TagServlet(RestServlet):
def __init__(self, hs):
super().__init__()
self.auth = hs.get_auth()
- self.store = hs.get_datastore()
- self.notifier = hs.get_notifier()
+ self.handler = hs.get_account_data_handler()
async def on_PUT(self, request, user_id, room_id, tag):
requester = await self.auth.get_user_by_req(request)
@@ -68,9 +67,7 @@ class TagServlet(RestServlet):
body = parse_json_object_from_request(request)
- max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body)
-
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+ await self.handler.add_tag_to_room(user_id, room_id, tag, body)
return 200, {}
@@ -79,9 +76,7 @@ class TagServlet(RestServlet):
if user_id != requester.user.to_string():
raise AuthError(403, "Cannot add tags for other users.")
- max_id = await self.store.remove_tag_from_room(user_id, room_id, tag)
-
- self.notifier.on_new_event("account_data_key", max_id, users=[user_id])
+ await self.handler.remove_tag_from_room(user_id, room_id, tag)
return 200, {}
diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index b3e4d5612e..8b9ef26cf2 100644
--- a/synapse/rest/consent/consent_resource.py
+++ b/synapse/rest/consent/consent_resource.py
@@ -100,6 +100,7 @@ class ConsentResource(DirectServeHtmlResource):
consent_template_directory = hs.config.user_consent_template_dir
+ # TODO: switch to synapse.util.templates.build_jinja_env
loader = jinja2.FileSystemLoader(consent_template_directory)
self._jinja_env = jinja2.Environment(
loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"])
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f843f02454..c57ac22e58 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import Dict, Set
+from typing import Dict
from signedjson.sign import sign_json
@@ -142,12 +142,13 @@ class RemoteKey(DirectServeJsonResource):
time_now_ms = self.clock.time_msec()
- cache_misses = {} # type: Dict[str, Set[str]]
+ # Note that the value is unused.
+ cache_misses = {} # type: Dict[str, Dict[str, int]]
for (server_name, key_id, from_server), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results]
if not results and key_id is not None:
- cache_misses.setdefault(server_name, set()).add(key_id)
+ cache_misses.setdefault(server_name, {})[key_id] = 0
continue
if key_id is not None:
@@ -201,7 +202,7 @@ class RemoteKey(DirectServeJsonResource):
)
if miss:
- cache_misses.setdefault(server_name, set()).add(key_id)
+ cache_misses.setdefault(server_name, {})[key_id] = 0
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
else:
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 67aa993f19..f71a03a12d 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2019 New Vector Ltd
+# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,10 +17,11 @@
import logging
import os
import urllib
-from typing import Awaitable
+from typing import Awaitable, Dict, Generator, List, Optional, Tuple
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
+from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
@@ -46,7 +47,7 @@ TEXT_CONTENT_TYPES = [
]
-def parse_media_id(request):
+def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try:
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
@@ -69,7 +70,7 @@ def parse_media_id(request):
)
-def respond_404(request):
+def respond_404(request: Request) -> None:
respond_with_json(
request,
404,
@@ -79,8 +80,12 @@ def respond_404(request):
async def respond_with_file(
- request, media_type, file_path, file_size=None, upload_name=None
-):
+ request: Request,
+ media_type: str,
+ file_path: str,
+ file_size: Optional[int] = None,
+ upload_name: Optional[str] = None,
+) -> None:
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
@@ -98,15 +103,20 @@ async def respond_with_file(
respond_404(request)
-def add_file_headers(request, media_type, file_size, upload_name):
+def add_file_headers(
+ request: Request,
+ media_type: str,
+ file_size: Optional[int],
+ upload_name: Optional[str],
+) -> None:
"""Adds the correct response headers in preparation for responding with the
media.
Args:
- request (twisted.web.http.Request)
- media_type (str): The media/content type.
- file_size (int): Size in bytes of the media, if known.
- upload_name (str): The name of the requested file, if any.
+ request
+ media_type: The media/content type.
+ file_size: Size in bytes of the media, if known.
+ upload_name: The name of the requested file, if any.
"""
def _quote(x):
@@ -153,7 +163,13 @@ def add_file_headers(request, media_type, file_size, upload_name):
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
- request.setHeader(b"Content-Length", b"%d" % (file_size,))
+ if file_size is not None:
+ request.setHeader(b"Content-Length", b"%d" % (file_size,))
+
+ # Tell web crawlers to not index, archive, or follow links in media. This
+ # should help to prevent things in the media repo from showing up in web
+ # search results.
+ request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
# separators as defined in RFC2616. SP and HT are handled separately.
@@ -179,7 +195,7 @@ _FILENAME_SEPARATOR_CHARS = {
}
-def _can_encode_filename_as_token(x):
+def _can_encode_filename_as_token(x: str) -> bool:
for c in x:
# from RFC2616:
#
@@ -201,17 +217,21 @@ def _can_encode_filename_as_token(x):
async def respond_with_responder(
- request, responder, media_type, file_size, upload_name=None
-):
+ request: Request,
+ responder: "Optional[Responder]",
+ media_type: str,
+ file_size: Optional[int],
+ upload_name: Optional[str] = None,
+) -> None:
"""Responds to the request with given responder. If responder is None then
returns 404.
Args:
- request (twisted.web.http.Request)
- responder (Responder|None)
- media_type (str): The media/content type.
- file_size (int|None): Size in bytes of the media. If not known it should be None
- upload_name (str|None): The name of the requested file, if any.
+ request
+ responder
+ media_type: The media/content type.
+ file_size: Size in bytes of the media. If not known it should be None
+ upload_name: The name of the requested file, if any.
"""
if request._disconnected:
logger.warning(
@@ -280,6 +300,7 @@ class FileInfo:
thumbnail_height (int)
thumbnail_method (str)
thumbnail_type (str): Content type of thumbnail, e.g. image/png
+ thumbnail_length (int): The size of the media file, in bytes.
"""
def __init__(
@@ -292,6 +313,7 @@ class FileInfo:
thumbnail_height=None,
thumbnail_method=None,
thumbnail_type=None,
+ thumbnail_length=None,
):
self.server_name = server_name
self.file_id = file_id
@@ -301,24 +323,25 @@ class FileInfo:
self.thumbnail_height = thumbnail_height
self.thumbnail_method = thumbnail_method
self.thumbnail_type = thumbnail_type
+ self.thumbnail_length = thumbnail_length
-def get_filename_from_headers(headers):
+def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:
"""
Get the filename of the downloaded file by inspecting the
Content-Disposition HTTP header.
Args:
- headers (dict[bytes, list[bytes]]): The HTTP request headers.
+ headers: The HTTP request headers.
Returns:
- A Unicode string of the filename, or None.
+ The filename, or None.
"""
content_disposition = headers.get(b"Content-Disposition", [b""])
# No header, bail out.
if not content_disposition[0]:
- return
+ return None
_, params = _parse_header(content_disposition[0])
@@ -351,17 +374,16 @@ def get_filename_from_headers(headers):
return upload_name
-def _parse_header(line):
+def _parse_header(line: bytes) -> Tuple[bytes, Dict[bytes, bytes]]:
"""Parse a Content-type like header.
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
- line (bytes): header to be parsed
+ line: header to be parsed
Returns:
- Tuple[bytes, dict[bytes, bytes]]:
- the main content-type, followed by the parameter dictionary
+ The main content-type, followed by the parameter dictionary
"""
parts = _parseparam(b";" + line)
key = next(parts)
@@ -381,16 +403,16 @@ def _parse_header(line):
return key, pdict
-def _parseparam(s):
+def _parseparam(s: bytes) -> Generator[bytes, None, None]:
"""Generator which splits the input on ;, respecting double-quoted sequences
Cargo-culted from `cgi`, but works on bytes rather than strings.
Args:
- s (bytes): header to be parsed
+ s: header to be parsed
Returns:
- Iterable[bytes]: the split input
+ The split input
"""
while s[:1] == b";":
s = s[1:]
diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py
index 68dd2a1c8a..4e4c6971f7 100644
--- a/synapse/rest/media/v1/config_resource.py
+++ b/synapse/rest/media/v1/config_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2018 Will Hunt <will@half-shot.uk>
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,22 +15,29 @@
# limitations under the License.
#
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
from synapse.http.server import DirectServeJsonResource, respond_with_json
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
class MediaConfigResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
config = hs.get_config()
self.clock = hs.get_clock()
self.auth = hs.get_auth()
self.limits_dict = {"m.upload.size": config.max_upload_size}
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
await self.auth.get_user_by_req(request)
respond_with_json(request, 200, self.limits_dict, send_cors=True)
- async def _async_render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index d3d8457303..3ed219ae43 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,24 +14,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
-import synapse.http.servlet
from synapse.http.server import DirectServeJsonResource, set_cors_headers
+from synapse.http.servlet import parse_boolean
from ._base import parse_media_id, respond_404
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+ from synapse.rest.media.v1.media_repository import MediaRepository
+
logger = logging.getLogger(__name__)
class DownloadResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
self.server_name = hs.hostname
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
request.setHeader(
b"Content-Security-Policy",
@@ -49,9 +57,7 @@ class DownloadResource(DirectServeJsonResource):
if server_name == self.server_name:
await self.media_repo.get_local_media(request, media_id, name)
else:
- allow_remote = synapse.http.servlet.parse_boolean(
- request, "allow_remote", default=True
- )
+ allow_remote = parse_boolean(request, "allow_remote", default=True)
if not allow_remote:
logger.info(
"Rejecting request for remote media %s/%s due to allow_remote",
diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index 9e079f672f..7792f26e78 100644
--- a/synapse/rest/media/v1/filepath.py
+++ b/synapse/rest/media/v1/filepath.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,11 +17,12 @@
import functools
import os
import re
+from typing import Callable, List
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
-def _wrap_in_base_path(func):
+def _wrap_in_base_path(func: "Callable[..., str]") -> "Callable[..., str]":
"""Takes a function that returns a relative path and turns it into an
absolute path based on the location of the primary media store
"""
@@ -41,12 +43,18 @@ class MediaFilePaths:
to write to the backup media store (when one is configured)
"""
- def __init__(self, primary_base_path):
+ def __init__(self, primary_base_path: str):
self.base_path = primary_base_path
def default_thumbnail_rel(
- self, default_top_level, default_sub_type, width, height, content_type, method
- ):
+ self,
+ default_top_level: str,
+ default_sub_type: str,
+ width: int,
+ height: int,
+ content_type: str,
+ method: str,
+ ) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@@ -55,12 +63,14 @@ class MediaFilePaths:
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
- def local_media_filepath_rel(self, media_id):
+ def local_media_filepath_rel(self, media_id: str) -> str:
return os.path.join("local_content", media_id[0:2], media_id[2:4], media_id[4:])
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
- def local_media_thumbnail_rel(self, media_id, width, height, content_type, method):
+ def local_media_thumbnail_rel(
+ self, media_id: str, width: int, height: int, content_type: str, method: str
+ ) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@@ -86,7 +96,7 @@ class MediaFilePaths:
media_id[4:],
)
- def remote_media_filepath_rel(self, server_name, file_id):
+ def remote_media_filepath_rel(self, server_name: str, file_id: str) -> str:
return os.path.join(
"remote_content", server_name, file_id[0:2], file_id[2:4], file_id[4:]
)
@@ -94,8 +104,14 @@ class MediaFilePaths:
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
def remote_media_thumbnail_rel(
- self, server_name, file_id, width, height, content_type, method
- ):
+ self,
+ server_name: str,
+ file_id: str,
+ width: int,
+ height: int,
+ content_type: str,
+ method: str,
+ ) -> str:
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method)
return os.path.join(
@@ -113,7 +129,7 @@ class MediaFilePaths:
# Should be removed after some time, when most of the thumbnails are stored
# using the new path.
def remote_media_thumbnail_rel_legacy(
- self, server_name, file_id, width, height, content_type
+ self, server_name: str, file_id: str, width: int, height: int, content_type: str
):
top_level_type, sub_type = content_type.split("/")
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
@@ -126,7 +142,7 @@ class MediaFilePaths:
file_name,
)
- def remote_media_thumbnail_dir(self, server_name, file_id):
+ def remote_media_thumbnail_dir(self, server_name: str, file_id: str) -> str:
return os.path.join(
self.base_path,
"remote_thumbnail",
@@ -136,7 +152,7 @@ class MediaFilePaths:
file_id[4:],
)
- def url_cache_filepath_rel(self, media_id):
+ def url_cache_filepath_rel(self, media_id: str) -> str:
if NEW_FORMAT_ID_RE.match(media_id):
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -146,7 +162,7 @@ class MediaFilePaths:
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
- def url_cache_filepath_dirs_to_delete(self, media_id):
+ def url_cache_filepath_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id file"
if NEW_FORMAT_ID_RE.match(media_id):
return [os.path.join(self.base_path, "url_cache", media_id[:10])]
@@ -156,7 +172,9 @@ class MediaFilePaths:
os.path.join(self.base_path, "url_cache", media_id[0:2]),
]
- def url_cache_thumbnail_rel(self, media_id, width, height, content_type, method):
+ def url_cache_thumbnail_rel(
+ self, media_id: str, width: int, height: int, content_type: str, method: str
+ ) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -178,7 +196,7 @@ class MediaFilePaths:
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
- def url_cache_thumbnail_directory(self, media_id):
+ def url_cache_thumbnail_directory(self, media_id: str) -> str:
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
@@ -195,7 +213,7 @@ class MediaFilePaths:
media_id[4:],
)
- def url_cache_thumbnail_dirs_to_delete(self, media_id):
+ def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]:
"The dirs to try and remove if we delete the media_id thumbnails"
# Media id is of the form <DATE><RANDOM_STRING>
# E.g.: 2017-09-28-fsdRDt24DS234dsf
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9cac74ebd8..4c9946a616 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,12 +13,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import errno
import logging
import os
import shutil
-from typing import IO, Dict, List, Optional, Tuple
+from io import BytesIO
+from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
@@ -56,6 +56,9 @@ from .thumbnail_resource import ThumbnailResource
from .thumbnailer import Thumbnailer, ThumbnailError
from .upload_resource import UploadResource
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -63,26 +66,26 @@ UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
- self.client = hs.get_http_client()
+ self.client = hs.get_federation_http_client()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
- self.primary_base_path = hs.config.media_store_path
- self.filepaths = MediaFilePaths(self.primary_base_path)
+ self.primary_base_path = hs.config.media_store_path # type: str
+ self.filepaths = MediaFilePaths(self.primary_base_path) # type: MediaFilePaths
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
- self.recently_accessed_remotes = set()
- self.recently_accessed_locals = set()
+ self.recently_accessed_remotes = set() # type: Set[Tuple[str, str]]
+ self.recently_accessed_locals = set() # type: Set[str]
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
@@ -113,7 +116,7 @@ class MediaRepository:
"update_recently_accessed_media", self._update_recently_accessed
)
- async def _update_recently_accessed(self):
+ async def _update_recently_accessed(self) -> None:
remote_media = self.recently_accessed_remotes
self.recently_accessed_remotes = set()
@@ -124,12 +127,12 @@ class MediaRepository:
local_media, remote_media, self.clock.time_msec()
)
- def mark_recently_accessed(self, server_name, media_id):
+ def mark_recently_accessed(self, server_name: Optional[str], media_id: str) -> None:
"""Mark the given media as recently accessed.
Args:
- server_name (str|None): Origin server of media, or None if local
- media_id (str): The media ID of the content
+ server_name: Origin server of media, or None if local
+ media_id: The media ID of the content
"""
if server_name:
self.recently_accessed_remotes.add((server_name, media_id))
@@ -459,7 +462,14 @@ class MediaRepository:
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
- def _generate_thumbnail(self, thumbnailer, t_width, t_height, t_method, t_type):
+ def _generate_thumbnail(
+ self,
+ thumbnailer: Thumbnailer,
+ t_width: int,
+ t_height: int,
+ t_method: str,
+ t_type: str,
+ ) -> Optional[BytesIO]:
m_width = thumbnailer.width
m_height = thumbnailer.height
@@ -470,22 +480,20 @@ class MediaRepository:
m_height,
self.max_image_pixels,
)
- return
+ return None
if thumbnailer.transpose_method is not None:
m_width, m_height = thumbnailer.transpose()
if t_method == "crop":
- t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
+ return thumbnailer.crop(t_width, t_height, t_type)
elif t_method == "scale":
t_width, t_height = thumbnailer.aspect(t_width, t_height)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
- t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
- else:
- t_byte_source = None
+ return thumbnailer.scale(t_width, t_height, t_type)
- return t_byte_source
+ return None
async def generate_local_exact_thumbnail(
self,
@@ -776,7 +784,7 @@ class MediaRepository:
return {"width": m_width, "height": m_height}
- async def delete_old_remote_media(self, before_ts):
+ async def delete_old_remote_media(self, before_ts: int) -> Dict[str, int]:
old_media = await self.store.get_remote_media_before(before_ts)
deleted = 0
@@ -928,7 +936,7 @@ class MediaRepositoryResource(Resource):
within a given rectangle.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
# If we're not configured to use it, raise if we somehow got here.
if not hs.config.can_load_media_repo:
raise ConfigError("Synapse is not configured to use a media repo.")
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 268e0c8f50..89cdd605aa 100644
--- a/synapse/rest/media/v1/media_storage.py
+++ b/synapse/rest/media/v1/media_storage.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vecotr Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@ import os
import shutil
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
+from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.logging.context import defer_to_thread, make_deferred_yieldable
@@ -270,7 +272,7 @@ class MediaStorage:
return self.filepaths.local_media_filepath_rel(file_info.file_id)
-def _write_file_synchronously(source, dest):
+def _write_file_synchronously(source: IO, dest: IO) -> None:
"""Write `source` to the file like `dest` synchronously. Should be called
from a thread.
@@ -286,14 +288,14 @@ class FileResponder(Responder):
"""Wraps an open file that can be sent to a request.
Args:
- open_file (file): A file like object to be streamed ot the client,
+ open_file: A file like object to be streamed ot the client,
is closed when finished streaming.
"""
- def __init__(self, open_file):
+ def __init__(self, open_file: IO):
self.open_file = open_file
- def write_to_consumer(self, consumer):
+ def write_to_consumer(self, consumer: IConsumer) -> Deferred:
return make_deferred_yieldable(
FileSender().beginFileTransfer(self.open_file, consumer)
)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index dce6c4d168..bf3be653aa 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import datetime
import errno
import fnmatch
@@ -23,12 +23,13 @@ import re
import shutil
import sys
import traceback
-from typing import Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict, Generator, Iterable, Optional, Union
from urllib import parse as urlparse
import attr
from twisted.internet.error import DNSLookupError
+from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
@@ -41,6 +42,7 @@ from synapse.http.servlet import parse_integer, parse_string
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.media.v1._base import get_filename_from_headers
+from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.expiringcache import ExpiringCache
@@ -48,6 +50,12 @@ from synapse.util.stringutils import random_string
from ._base import FileInfo
+if TYPE_CHECKING:
+ from lxml import etree
+
+ from synapse.app.homeserver import HomeServer
+ from synapse.rest.media.v1.media_repository import MediaRepository
+
logger = logging.getLogger(__name__)
_charset_match = re.compile(br"<\s*meta[^>]*charset\s*=\s*([a-z0-9-]+)", flags=re.I)
@@ -119,7 +127,12 @@ class OEmbedError(Exception):
class PreviewUrlResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs, media_repo, media_storage):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ media_repo: "MediaRepository",
+ media_storage: MediaStorage,
+ ):
super().__init__()
self.auth = hs.get_auth()
@@ -166,11 +179,11 @@ class PreviewUrlResource(DirectServeJsonResource):
self._start_expire_url_cache_data, 10 * 1000
)
- async def _async_render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request: Request) -> None:
request.setHeader(b"Allow", b"OPTIONS, GET")
respond_with_json(request, 200, {}, send_cors=True)
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = await self.auth.get_user_by_req(request)
@@ -373,7 +386,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Check whether the URL should be downloaded as oEmbed content instead.
- Params:
+ Args:
url: The URL to check.
Returns:
@@ -390,7 +403,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Request content from an oEmbed endpoint.
- Params:
+ Args:
endpoint: The oEmbed API endpoint.
url: The URL to pass to the API.
@@ -450,7 +463,7 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.warning("Error downloading oEmbed metadata from %s: %r", url, e)
raise OEmbedError() from e
- async def _download_url(self, url: str, user):
+ async def _download_url(self, url: str, user: str) -> Dict[str, Any]:
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
@@ -580,7 +593,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"expire_url_cache_data", self._expire_url_cache_data
)
- async def _expire_url_cache_data(self):
+ async def _expire_url_cache_data(self) -> None:
"""Clean up expired url cache content, media and thumbnails.
"""
# TODO: Delete from backup media store
@@ -676,24 +689,54 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
-def decode_and_calc_og(body, media_uri, request_encoding=None):
+def decode_and_calc_og(
+ body: bytes, media_uri: str, request_encoding: Optional[str] = None
+) -> Dict[str, Optional[str]]:
+ """
+ Calculate metadata for an HTML document.
+
+ This uses lxml to parse the HTML document into the OG response. If errors
+ occur during processing of the document, an empty response is returned.
+
+ Args:
+ body: The HTML document, as bytes.
+ media_url: The URI used to download the body.
+ request_encoding: The character encoding of the body, as a string.
+
+ Returns:
+ The OG response as a dictionary.
+ """
+ # If there's no body, nothing useful is going to be found.
+ if not body:
+ return {}
+
from lxml import etree
+ # Create an HTML parser. If this fails, log and return no metadata.
try:
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body, parser)
- og = _calc_og(tree, media_uri)
+ except LookupError:
+ # blindly consider the encoding as utf-8.
+ parser = etree.HTMLParser(recover=True, encoding="utf-8")
+ except Exception as e:
+ logger.warning("Unable to create HTML parser: %s" % (e,))
+ return {}
+
+ def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
+ # Attempt to parse the body. If this fails, log and return no metadata.
+ tree = etree.fromstring(body_attempt, parser)
+ return _calc_og(tree, media_uri)
+
+ # Attempt to parse the body. If this fails, log and return no metadata.
+ try:
+ return _attempt_calc_og(body)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
- parser = etree.HTMLParser(recover=True, encoding=request_encoding)
- tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
- og = _calc_og(tree, media_uri)
-
- return og
+ return _attempt_calc_og(body.decode("utf-8", "ignore"))
-def _calc_og(tree, media_uri):
+def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
@@ -797,7 +840,9 @@ def _calc_og(tree, media_uri):
for el in _iterate_over_text(tree.find("body"), *TAGS_TO_REMOVE)
)
og["og:description"] = summarize_paragraphs(text_nodes)
- else:
+ elif og["og:description"]:
+ # This must be a non-empty string at this point.
+ assert isinstance(og["og:description"], str)
og["og:description"] = summarize_paragraphs([og["og:description"]])
# TODO: delete the url downloads to stop diskfilling,
@@ -805,7 +850,9 @@ def _calc_og(tree, media_uri):
return og
-def _iterate_over_text(tree, *tags_to_ignore):
+def _iterate_over_text(
+ tree, *tags_to_ignore: Iterable[Union[str, "etree.Comment"]]
+) -> Generator[str, None, None]:
"""Iterate over the tree returning text nodes in a depth first fashion,
skipping text nodes inside certain tags.
"""
@@ -839,32 +886,32 @@ def _iterate_over_text(tree, *tags_to_ignore):
)
-def _rebase_url(url, base):
- base = list(urlparse.urlparse(base))
- url = list(urlparse.urlparse(url))
- if not url[0]: # fix up schema
- url[0] = base[0] or "http"
- if not url[1]: # fix up hostname
- url[1] = base[1]
- if not url[2].startswith("/"):
- url[2] = re.sub(r"/[^/]+$", "/", base[2]) + url[2]
- return urlparse.urlunparse(url)
+def _rebase_url(url: str, base: str) -> str:
+ base_parts = list(urlparse.urlparse(base))
+ url_parts = list(urlparse.urlparse(url))
+ if not url_parts[0]: # fix up schema
+ url_parts[0] = base_parts[0] or "http"
+ if not url_parts[1]: # fix up hostname
+ url_parts[1] = base_parts[1]
+ if not url_parts[2].startswith("/"):
+ url_parts[2] = re.sub(r"/[^/]+$", "/", base_parts[2]) + url_parts[2]
+ return urlparse.urlunparse(url_parts)
-def _is_media(content_type):
- if content_type.lower().startswith("image/"):
- return True
+def _is_media(content_type: str) -> bool:
+ return content_type.lower().startswith("image/")
-def _is_html(content_type):
+def _is_html(content_type: str) -> bool:
content_type = content_type.lower()
- if content_type.startswith("text/html") or content_type.startswith(
+ return content_type.startswith("text/html") or content_type.startswith(
"application/xhtml"
- ):
- return True
+ )
-def summarize_paragraphs(text_nodes, min_size=200, max_size=500):
+def summarize_paragraphs(
+ text_nodes: Iterable[str], min_size: int = 200, max_size: int = 500
+) -> Optional[str]:
# Try to get a summary of between 200 and 500 words, respecting
# first paragraph and then word boundaries.
# TODO: Respect sentences?
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 18c9ed48d6..e92006faa9 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2018 New Vector Ltd
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,27 +13,32 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
+import abc
import logging
import os
import shutil
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
+from synapse.util.async_helpers import maybe_awaitable
from ._base import FileInfo, Responder
from .media_storage import FileResponder
logger = logging.getLogger(__name__)
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
-class StorageProvider:
+
+class StorageProvider(metaclass=abc.ABCMeta):
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
- async def store_file(self, path: str, file_info: FileInfo):
+ @abc.abstractmethod
+ async def store_file(self, path: str, file_info: FileInfo) -> None:
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
@@ -42,6 +47,7 @@ class StorageProvider:
file_info: The metadata of the file.
"""
+ @abc.abstractmethod
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
@@ -78,10 +84,10 @@ class StorageProviderWrapper(StorageProvider):
self.store_synchronous = store_synchronous
self.store_remote = store_remote
- def __str__(self):
+ def __str__(self) -> str:
return "StorageProviderWrapper[%s]" % (self.backend,)
- async def store_file(self, path, file_info):
+ async def store_file(self, path: str, file_info: FileInfo) -> None:
if not file_info.server_name and not self.store_local:
return None
@@ -91,39 +97,34 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
- result = self.backend.store_file(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ await maybe_awaitable(self.backend.store_file(path, file_info)) # type: ignore
else:
# TODO: Handle errors.
async def store():
try:
- result = self.backend.store_file(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ return await maybe_awaitable(
+ self.backend.store_file(path, file_info)
+ )
except Exception:
logger.exception("Error storing file")
run_in_background(store)
- return None
- async def fetch(self, path, file_info):
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
- result = self.backend.fetch(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ return await maybe_awaitable(self.backend.fetch(path, file_info))
class FileStorageProviderBackend(StorageProvider):
"""A storage provider that stores files in a directory on a filesystem.
Args:
- hs (HomeServer)
+ hs
config: The config returned by `parse_config`.
"""
- def __init__(self, hs, config):
+ def __init__(self, hs: "HomeServer", config: str):
self.hs = hs
self.cache_directory = hs.config.media_store_path
self.base_directory = config
@@ -131,7 +132,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
- async def store_file(self, path, file_info):
+ async def store_file(self, path: str, file_info: FileInfo) -> None:
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
@@ -141,19 +142,21 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
- return await defer_to_thread(
+ await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
- async def fetch(self, path, file_info):
+ async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
if os.path.isfile(backup_fname):
return FileResponder(open(backup_fname, "rb"))
+ return None
+
@staticmethod
- def parse_config(config):
+ def parse_config(config: dict) -> str:
"""Called on startup to parse config supplied. This should parse
the config and raise if there is a problem.
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 30421b663a..d653a58be9 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
-# Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -15,10 +16,14 @@
import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+
+from twisted.web.http import Request
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_integer, parse_string
+from synapse.rest.media.v1.media_storage import MediaStorage
from ._base import (
FileInfo,
@@ -28,13 +33,22 @@ from ._base import (
respond_with_responder,
)
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+ from synapse.rest.media.v1.media_repository import MediaRepository
+
logger = logging.getLogger(__name__)
class ThumbnailResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs, media_repo, media_storage):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ media_repo: "MediaRepository",
+ media_storage: MediaStorage,
+ ):
super().__init__()
self.store = hs.get_datastore()
@@ -43,7 +57,7 @@ class ThumbnailResource(DirectServeJsonResource):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
- async def _async_render_GET(self, request):
+ async def _async_render_GET(self, request: Request) -> None:
set_cors_headers(request)
server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width", required=True)
@@ -73,8 +87,14 @@ class ThumbnailResource(DirectServeJsonResource):
self.media_repo.mark_recently_accessed(server_name, media_id)
async def _respond_local_thumbnail(
- self, request, media_id, width, height, method, m_type
- ):
+ self,
+ request: Request,
+ media_id: str,
+ width: int,
+ height: int,
+ method: str,
+ m_type: str,
+ ) -> None:
media_info = await self.store.get_local_media(media_id)
if not media_info:
@@ -86,41 +106,27 @@ class ThumbnailResource(DirectServeJsonResource):
return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
-
- if thumbnail_infos:
- thumbnail_info = self._select_thumbnail(
- width, height, method, m_type, thumbnail_infos
- )
-
- file_info = FileInfo(
- server_name=None,
- file_id=media_id,
- url_cache=media_info["url_cache"],
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- )
-
- t_type = file_info.thumbnail_type
- t_length = thumbnail_info["thumbnail_length"]
-
- responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(request, responder, t_type, t_length)
- else:
- logger.info("Couldn't find any generated thumbnails")
- respond_404(request)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_id,
+ url_cache=media_info["url_cache"],
+ server_name=None,
+ )
async def _select_or_generate_local_thumbnail(
self,
- request,
- media_id,
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- ):
+ request: Request,
+ media_id: str,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ ) -> None:
media_info = await self.store.get_local_media(media_id)
if not media_info:
@@ -178,14 +184,14 @@ class ThumbnailResource(DirectServeJsonResource):
async def _select_or_generate_remote_thumbnail(
self,
- request,
- server_name,
- media_id,
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- ):
+ request: Request,
+ server_name: str,
+ media_id: str,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ ) -> None:
media_info = await self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = await self.store.get_remote_media_thumbnails(
@@ -239,8 +245,15 @@ class ThumbnailResource(DirectServeJsonResource):
raise SynapseError(400, "Failed to generate thumbnail.")
async def _respond_remote_thumbnail(
- self, request, server_name, media_id, width, height, method, m_type
- ):
+ self,
+ request: Request,
+ server_name: str,
+ media_id: str,
+ width: int,
+ height: int,
+ method: str,
+ m_type: str,
+ ) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
@@ -249,97 +262,185 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id
)
+ await self._select_and_respond_with_thumbnail(
+ request,
+ width,
+ height,
+ method,
+ m_type,
+ thumbnail_infos,
+ media_info["filesystem_id"],
+ url_cache=None,
+ server_name=server_name,
+ )
+ async def _select_and_respond_with_thumbnail(
+ self,
+ request: Request,
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ thumbnail_infos: List[Dict[str, Any]],
+ file_id: str,
+ url_cache: Optional[str] = None,
+ server_name: Optional[str] = None,
+ ) -> None:
+ """
+ Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ request: The incoming request.
+ desired_width: The desired width, the returned thumbnail may be larger than this.
+ desired_height: The desired height, the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: The URL cache value.
+ server_name: The server name, if this is a remote thumbnail.
+ """
if thumbnail_infos:
- thumbnail_info = self._select_thumbnail(
- width, height, method, m_type, thumbnail_infos
+ file_info = self._select_thumbnail(
+ desired_width,
+ desired_height,
+ desired_method,
+ desired_type,
+ thumbnail_infos,
+ file_id,
+ url_cache,
+ server_name,
)
- file_info = FileInfo(
- server_name=server_name,
- file_id=media_info["filesystem_id"],
- thumbnail=True,
- thumbnail_width=thumbnail_info["thumbnail_width"],
- thumbnail_height=thumbnail_info["thumbnail_height"],
- thumbnail_type=thumbnail_info["thumbnail_type"],
- thumbnail_method=thumbnail_info["thumbnail_method"],
- )
-
- t_type = file_info.thumbnail_type
- t_length = thumbnail_info["thumbnail_length"]
+ if not file_info:
+ logger.info("Couldn't find a thumbnail matching the desired inputs")
+ respond_404(request)
+ return
responder = await self.media_storage.fetch_media(file_info)
- await respond_with_responder(request, responder, t_type, t_length)
+ await respond_with_responder(
+ request, responder, file_info.thumbnail_type, file_info.thumbnail_length
+ )
else:
logger.info("Failed to find any generated thumbnails")
respond_404(request)
def _select_thumbnail(
self,
- desired_width,
- desired_height,
- desired_method,
- desired_type,
- thumbnail_infos,
- ):
+ desired_width: int,
+ desired_height: int,
+ desired_method: str,
+ desired_type: str,
+ thumbnail_infos: List[Dict[str, Any]],
+ file_id: str,
+ url_cache: Optional[str],
+ server_name: Optional[str],
+ ) -> Optional[FileInfo]:
+ """
+ Choose an appropriate thumbnail from the previously generated thumbnails.
+
+ Args:
+ desired_width: The desired width, the returned thumbnail may be larger than this.
+ desired_height: The desired height, the returned thumbnail may be larger than this.
+ desired_method: The desired method used to generate the thumbnail.
+ desired_type: The desired content-type of the thumbnail.
+ thumbnail_infos: A list of dictionaries of candidate thumbnails.
+ file_id: The ID of the media that a thumbnail is being requested for.
+ url_cache: The URL cache value.
+ server_name: The server name, if this is a remote thumbnail.
+
+ Returns:
+ The thumbnail which best matches the desired parameters.
+ """
+ desired_method = desired_method.lower()
+
+ # The chosen thumbnail.
+ thumbnail_info = None
+
d_w = desired_width
d_h = desired_height
- if desired_method.lower() == "crop":
+ if desired_method == "crop":
+ # Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list = []
+ # Other thumbnails.
crop_info_list2 = []
for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info["thumbnail_method"] != "crop":
+ continue
+
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
- t_method = info["thumbnail_method"]
- if t_method == "crop":
- aspect_quality = abs(d_w * t_h - d_h * t_w)
- min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
- size_quality = abs((d_w - t_w) * (d_h - t_h))
- type_quality = desired_type != info["thumbnail_type"]
- length_quality = info["thumbnail_length"]
- if t_w >= d_w or t_h >= d_h:
- crop_info_list.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
+ aspect_quality = abs(d_w * t_h - d_h * t_w)
+ min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
+ size_quality = abs((d_w - t_w) * (d_h - t_h))
+ type_quality = desired_type != info["thumbnail_type"]
+ length_quality = info["thumbnail_length"]
+ if t_w >= d_w or t_h >= d_h:
+ crop_info_list.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
)
- else:
- crop_info_list2.append(
- (
- aspect_quality,
- min_quality,
- size_quality,
- type_quality,
- length_quality,
- info,
- )
+ )
+ else:
+ crop_info_list2.append(
+ (
+ aspect_quality,
+ min_quality,
+ size_quality,
+ type_quality,
+ length_quality,
+ info,
)
+ )
if crop_info_list:
- return min(crop_info_list)[-1]
- else:
- return min(crop_info_list2)[-1]
- else:
+ thumbnail_info = min(crop_info_list)[-1]
+ elif crop_info_list2:
+ thumbnail_info = min(crop_info_list2)[-1]
+ elif desired_method == "scale":
+ # Thumbnails that match equal or larger sizes of desired width/height.
info_list = []
+ # Other thumbnails.
info_list2 = []
+
for info in thumbnail_infos:
+ # Skip thumbnails generated with different methods.
+ if info["thumbnail_method"] != "scale":
+ continue
+
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
- t_method = info["thumbnail_method"]
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
- if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
+ if t_w >= d_w or t_h >= d_h:
info_list.append((size_quality, type_quality, length_quality, info))
- elif t_method == "scale":
+ else:
info_list2.append(
(size_quality, type_quality, length_quality, info)
)
if info_list:
- return min(info_list)[-1]
- else:
- return min(info_list2)[-1]
+ thumbnail_info = min(info_list)[-1]
+ elif info_list2:
+ thumbnail_info = min(info_list2)[-1]
+
+ if thumbnail_info:
+ return FileInfo(
+ file_id=file_id,
+ url_cache=url_cache,
+ server_name=server_name,
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
+ thumbnail_length=thumbnail_info["thumbnail_length"],
+ )
+
+ # No matching thumbnail was found.
+ return None
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 32a8e4f960..07903e4017 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,6 +15,7 @@
# limitations under the License.
import logging
from io import BytesIO
+from typing import Tuple
from PIL import Image
@@ -39,7 +41,7 @@ class Thumbnailer:
FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"}
- def __init__(self, input_path):
+ def __init__(self, input_path: str):
try:
self.image = Image.open(input_path)
except OSError as e:
@@ -59,11 +61,11 @@ class Thumbnailer:
# A lot of parsing errors can happen when parsing EXIF
logger.info("Error parsing image EXIF information: %s", e)
- def transpose(self):
+ def transpose(self) -> Tuple[int, int]:
"""Transpose the image using its EXIF Orientation tag
Returns:
- Tuple[int, int]: (width, height) containing the new image size in pixels.
+ A tuple containing the new image size in pixels as (width, height).
"""
if self.transpose_method is not None:
self.image = self.image.transpose(self.transpose_method)
@@ -73,7 +75,7 @@ class Thumbnailer:
self.image.info["exif"] = None
return self.image.size
- def aspect(self, max_width, max_height):
+ def aspect(self, max_width: int, max_height: int) -> Tuple[int, int]:
"""Calculate the largest size that preserves aspect ratio which
fits within the given rectangle::
@@ -91,7 +93,7 @@ class Thumbnailer:
else:
return (max_height * self.width) // self.height, max_height
- def _resize(self, width, height):
+ def _resize(self, width: int, height: int) -> Image:
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
# looks awful
@@ -99,7 +101,7 @@ class Thumbnailer:
self.image = self.image.convert("RGB")
return self.image.resize((width, height), Image.ANTIALIAS)
- def scale(self, width, height, output_type):
+ def scale(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales the image to the given dimensions.
Returns:
@@ -108,7 +110,7 @@ class Thumbnailer:
scaled = self._resize(width, height)
return self._encode_image(scaled, output_type)
- def crop(self, width, height, output_type):
+ def crop(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales and crops the image to the given dimensions preserving
aspect::
(w_in / h_in) = (w_scaled / h_scaled)
@@ -136,7 +138,7 @@ class Thumbnailer:
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self._encode_image(cropped, output_type)
- def _encode_image(self, output_image, output_type):
+ def _encode_image(self, output_image: Image, output_type: str) -> BytesIO:
output_bytes_io = BytesIO()
fmt = self.FORMATS[output_type]
if fmt == "JPEG":
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index d76f7389e1..6da76ae994 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,18 +15,25 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_string
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+ from synapse.rest.media.v1.media_repository import MediaRepository
+
logger = logging.getLogger(__name__)
class UploadResource(DirectServeJsonResource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
@@ -37,14 +45,14 @@ class UploadResource(DirectServeJsonResource):
self.max_upload_size = hs.config.max_upload_size
self.clock = hs.get_clock()
- async def _async_render_OPTIONS(self, request):
+ async def _async_render_OPTIONS(self, request: Request) -> None:
respond_with_json(request, 200, {}, send_cors=True)
- async def _async_render_POST(self, request):
+ async def _async_render_POST(self, request: Request) -> None:
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
- content_length = request.getHeader(b"Content-Length").decode("ascii")
+ content_length = request.getHeader("Content-Length")
if content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400)
if int(content_length) > self.max_upload_size:
diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py
index c0b733488b..e5ef515090 100644
--- a/synapse/rest/synapse/client/__init__.py
+++ b/synapse/rest/synapse/client/__init__.py
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,3 +12,55 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from typing import TYPE_CHECKING, Mapping
+
+from twisted.web.resource import Resource
+
+from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource
+from synapse.rest.synapse.client.pick_idp import PickIdpResource
+from synapse.rest.synapse.client.pick_username import pick_username_resource
+from synapse.rest.synapse.client.sso_register import SsoRegisterResource
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+
+def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resource]:
+ """Builds a resource tree to include synapse-specific client resources
+
+ These are resources which should be loaded on all workers which expose a C-S API:
+ ie, the main process, and any generic workers so configured.
+
+ Returns:
+ map from path to Resource.
+ """
+ resources = {
+ # SSO bits. These are always loaded, whether or not SSO login is actually
+ # enabled (they just won't work very well if it's not)
+ "/_synapse/client/pick_idp": PickIdpResource(hs),
+ "/_synapse/client/pick_username": pick_username_resource(hs),
+ "/_synapse/client/new_user_consent": NewUserConsentResource(hs),
+ "/_synapse/client/sso_register": SsoRegisterResource(hs),
+ }
+
+ # provider-specific SSO bits. Only load these if they are enabled, since they
+ # rely on optional dependencies.
+ if hs.config.oidc_enabled:
+ from synapse.rest.synapse.client.oidc import OIDCResource
+
+ resources["/_synapse/client/oidc"] = OIDCResource(hs)
+
+ if hs.config.saml2_enabled:
+ from synapse.rest.synapse.client.saml2 import SAML2Resource
+
+ res = SAML2Resource(hs)
+ resources["/_synapse/client/saml2"] = res
+
+ # This is also mounted under '/_matrix' for backwards-compatibility.
+ resources["/_matrix/saml2"] = res
+
+ return resources
+
+
+__all__ = ["build_synapse_client_resource_tree"]
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
new file mode 100644
index 0000000000..b2e0f93810
--- /dev/null
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -0,0 +1,97 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource, respond_with_html
+from synapse.http.servlet import parse_string
+from synapse.types import UserID
+from synapse.util.templates import build_jinja_env
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class NewUserConsentResource(DirectServeHtmlResource):
+ """A resource which collects consent to the server's terms from a new user
+
+ This resource gets mounted at /_synapse/client/new_user_consent, and is shown
+ when we are automatically creating a new user due to an SSO login.
+
+ It shows a template which prompts the user to go and read the Ts and Cs, and click
+ a clickybox if they have done so.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+ self._server_name = hs.hostname
+ self._consent_version = hs.config.consent.user_consent_version
+
+ def template_search_dirs():
+ if hs.config.sso.sso_template_dir:
+ yield hs.config.sso.sso_template_dir
+ yield hs.config.sso.default_template_dir
+
+ self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+
+ async def _async_render_GET(self, request: Request) -> None:
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ session = self._sso_handler.get_mapping_session(session_id)
+ except SynapseError as e:
+ logger.warning("Error fetching session: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ user_id = UserID(session.chosen_localpart, self._server_name)
+ user_profile = {
+ "display_name": session.display_name,
+ }
+
+ template_params = {
+ "user_id": user_id.to_string(),
+ "user_profile": user_profile,
+ "consent_version": self._consent_version,
+ "terms_url": "/_matrix/consent?v=%s" % (self._consent_version,),
+ }
+
+ template = self._jinja_env.get_template("sso_new_user_consent.html")
+ html = template.render(template_params)
+ respond_with_html(request, 200, html)
+
+ async def _async_render_POST(self, request: Request):
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ except SynapseError as e:
+ logger.warning("Error fetching session cookie: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ try:
+ accepted_version = parse_string(request, "accepted_version", required=True)
+ except SynapseError as e:
+ self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code)
+ return
+
+ await self._sso_handler.handle_terms_accepted(
+ request, session_id, accepted_version
+ )
diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py
index d958dd65bb..64c0deb75d 100644
--- a/synapse/rest/oidc/__init__.py
+++ b/synapse/rest/synapse/client/oidc/__init__.py
@@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
from twisted.web.resource import Resource
-from synapse.rest.oidc.callback_resource import OIDCCallbackResource
+from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
logger = logging.getLogger(__name__)
@@ -25,3 +26,6 @@ class OIDCResource(Resource):
def __init__(self, hs):
Resource.__init__(self)
self.putChild(b"callback", OIDCCallbackResource(hs))
+
+
+__all__ = ["OIDCResource"]
diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py
index f7a0bc4bdb..f7a0bc4bdb 100644
--- a/synapse/rest/oidc/callback_resource.py
+++ b/synapse/rest/synapse/client/oidc/callback_resource.py
diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py
new file mode 100644
index 0000000000..9550b82998
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_idp.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.server import (
+ DirectServeHtmlResource,
+ finish_request,
+ respond_with_html,
+)
+from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class PickIdpResource(DirectServeHtmlResource):
+ """IdP picker resource.
+
+ This resource gets mounted under /_synapse/client/pick_idp. It serves an HTML page
+ which prompts the user to choose an Identity Provider from the list.
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+ self._sso_login_idp_picker_template = (
+ hs.config.sso.sso_login_idp_picker_template
+ )
+ self._server_name = hs.hostname
+
+ async def _async_render_GET(self, request: SynapseRequest) -> None:
+ client_redirect_url = parse_string(
+ request, "redirectUrl", required=True, encoding="utf-8"
+ )
+ idp = parse_string(request, "idp", required=False)
+
+ # if we need to pick an IdP, do so
+ if not idp:
+ return await self._serve_id_picker(request, client_redirect_url)
+
+ # otherwise, redirect to the IdP's redirect URI
+ providers = self._sso_handler.get_identity_providers()
+ auth_provider = providers.get(idp)
+ if not auth_provider:
+ logger.info("Unknown idp %r", idp)
+ self._sso_handler.render_error(
+ request, "unknown_idp", "Unknown identity provider ID"
+ )
+ return
+
+ sso_url = await auth_provider.handle_redirect_request(
+ request, client_redirect_url.encode("utf8")
+ )
+ logger.info("Redirecting to %s", sso_url)
+ request.redirect(sso_url)
+ finish_request(request)
+
+ async def _serve_id_picker(
+ self, request: SynapseRequest, client_redirect_url: str
+ ) -> None:
+ # otherwise, serve up the IdP picker
+ providers = self._sso_handler.get_identity_providers()
+ html = self._sso_login_idp_picker_template.render(
+ redirect_url=client_redirect_url,
+ server_name=self._server_name,
+ providers=providers.values(),
+ )
+ respond_with_html(request, 200, html)
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py
new file mode 100644
index 0000000000..96077cfcd1
--- /dev/null
+++ b/synapse/rest/synapse/client/pick_username.py
@@ -0,0 +1,131 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, List
+
+from twisted.web.http import Request
+from twisted.web.resource import Resource
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import (
+ DirectServeHtmlResource,
+ DirectServeJsonResource,
+ respond_with_html,
+)
+from synapse.http.servlet import parse_boolean, parse_string
+from synapse.http.site import SynapseRequest
+from synapse.util.templates import build_jinja_env
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+def pick_username_resource(hs: "HomeServer") -> Resource:
+ """Factory method to generate the username picker resource.
+
+ This resource gets mounted under /_synapse/client/pick_username and has two
+ children:
+
+ * "account_details": renders the form and handles the POSTed response
+ * "check": a JSON endpoint which checks if a userid is free.
+ """
+
+ res = Resource()
+ res.putChild(b"account_details", AccountDetailsResource(hs))
+ res.putChild(b"check", AvailabilityCheckResource(hs))
+
+ return res
+
+
+class AvailabilityCheckResource(DirectServeJsonResource):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+
+ async def _async_render_GET(self, request: Request):
+ localpart = parse_string(request, "username", required=True)
+
+ session_id = get_username_mapping_session_cookie_from_request(request)
+
+ is_available = await self._sso_handler.check_username_availability(
+ localpart, session_id
+ )
+ return 200, {"available": is_available}
+
+
+class AccountDetailsResource(DirectServeHtmlResource):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+
+ def template_search_dirs():
+ if hs.config.sso.sso_template_dir:
+ yield hs.config.sso.sso_template_dir
+ yield hs.config.sso.default_template_dir
+
+ self._jinja_env = build_jinja_env(template_search_dirs(), hs.config)
+
+ async def _async_render_GET(self, request: Request) -> None:
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ session = self._sso_handler.get_mapping_session(session_id)
+ except SynapseError as e:
+ logger.warning("Error fetching session: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ idp_id = session.auth_provider_id
+ template_params = {
+ "idp": self._sso_handler.get_identity_providers()[idp_id],
+ "user_attributes": {
+ "display_name": session.display_name,
+ "emails": session.emails,
+ },
+ }
+
+ template = self._jinja_env.get_template("sso_auth_account_details.html")
+ html = template.render(template_params)
+ respond_with_html(request, 200, html)
+
+ async def _async_render_POST(self, request: SynapseRequest):
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ except SynapseError as e:
+ logger.warning("Error fetching session cookie: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+
+ try:
+ localpart = parse_string(request, "username", required=True)
+ use_display_name = parse_boolean(request, "use_display_name", default=False)
+
+ try:
+ emails_to_use = [
+ val.decode("utf-8") for val in request.args.get(b"use_email", [])
+ ] # type: List[str]
+ except ValueError:
+ raise SynapseError(400, "Query parameter use_email must be utf-8")
+ except SynapseError as e:
+ logger.warning("[session %s] bad param: %s", session_id, e)
+ self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code)
+ return
+
+ await self._sso_handler.handle_submit_username_request(
+ request, session_id, localpart, use_display_name, emails_to_use
+ )
diff --git a/synapse/rest/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py
index 68da37ca6a..3e8235ee1e 100644
--- a/synapse/rest/saml2/__init__.py
+++ b/synapse/rest/synapse/client/saml2/__init__.py
@@ -12,12 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import logging
from twisted.web.resource import Resource
-from synapse.rest.saml2.metadata_resource import SAML2MetadataResource
-from synapse.rest.saml2.response_resource import SAML2ResponseResource
+from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource
+from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource
logger = logging.getLogger(__name__)
@@ -27,3 +28,6 @@ class SAML2Resource(Resource):
Resource.__init__(self)
self.putChild(b"metadata.xml", SAML2MetadataResource(hs))
self.putChild(b"authn_response", SAML2ResponseResource(hs))
+
+
+__all__ = ["SAML2Resource"]
diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py
index 1e8526e22e..1e8526e22e 100644
--- a/synapse/rest/saml2/metadata_resource.py
+++ b/synapse/rest/synapse/client/saml2/metadata_resource.py
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py
index f6668fb5e3..f6668fb5e3 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/synapse/client/saml2/response_resource.py
diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py
new file mode 100644
index 0000000000..dfefeb7796
--- /dev/null
+++ b/synapse/rest/synapse/client/sso_register.py
@@ -0,0 +1,50 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from twisted.web.http import Request
+
+from synapse.api.errors import SynapseError
+from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
+from synapse.http.server import DirectServeHtmlResource
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class SsoRegisterResource(DirectServeHtmlResource):
+ """A resource which completes SSO registration
+
+ This resource gets mounted at /_synapse/client/sso_register, and is shown
+ after we collect username and/or consent for a new SSO user. It (finally) registers
+ the user, and confirms redirect to the client
+ """
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._sso_handler = hs.get_sso_handler()
+
+ async def _async_render_GET(self, request: Request) -> None:
+ try:
+ session_id = get_username_mapping_session_cookie_from_request(request)
+ except SynapseError as e:
+ logger.warning("Error fetching session cookie: %s", e)
+ self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
+ return
+ await self._sso_handler.register_sso_user(request, session_id)
|