diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py
index b712215112..28542cd774 100644
--- a/synapse/rest/__init__.py
+++ b/synapse/rest/__init__.py
@@ -30,6 +30,7 @@ from synapse.rest.client import (
keys,
knock,
login as v1_login,
+ login_token_request,
logout,
mutual_rooms,
notifications,
@@ -43,6 +44,7 @@ from synapse.rest.client import (
receipts,
register,
relations,
+ rendezvous,
report_event,
room,
room_batch,
@@ -130,3 +132,5 @@ class ClientRestResource(JsonResource):
# unstable
mutual_rooms.register_servlets(hs, client_resource)
+ login_token_request.register_servlets(hs, client_resource)
+ rendezvous.register_servlets(hs, client_resource)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index fa3266720b..fb73886df0 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -61,9 +61,11 @@ from synapse.rest.admin.rooms import (
MakeRoomAdminRestServlet,
RoomEventContextServlet,
RoomMembersRestServlet,
+ RoomMessagesRestServlet,
RoomRestServlet,
RoomRestV2Servlet,
RoomStateRestServlet,
+ RoomTimestampToEventRestServlet,
)
from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
from synapse.rest.admin.statistics import UserMediaStatisticsRestServlet
@@ -78,6 +80,8 @@ from synapse.rest.admin.users import (
SearchUsersRestServlet,
ShadowBanRestServlet,
UserAdminServlet,
+ UserByExternalId,
+ UserByThreePid,
UserMembershipRestServlet,
UserRegisterServlet,
UserRestServletV2,
@@ -234,6 +238,10 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
"""
Register all the admin servlets.
"""
+ # Admin servlets aren't registered on workers.
+ if hs.config.worker.worker_app is not None:
+ return
+
register_servlets_for_client_rest_resource(hs, http_server)
BlockRoomRestServlet(hs).register(http_server)
ListRoomRestServlet(hs).register(http_server)
@@ -250,9 +258,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
UserTokenRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
- DeviceRestServlet(hs).register(http_server)
- DevicesRestServlet(hs).register(http_server)
- DeleteDevicesRestServlet(hs).register(http_server)
UserMediaStatisticsRestServlet(hs).register(http_server)
EventReportDetailRestServlet(hs).register(http_server)
EventReportsRestServlet(hs).register(http_server)
@@ -271,13 +276,18 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
DestinationResetConnectionRestServlet(hs).register(http_server)
DestinationRestServlet(hs).register(http_server)
ListDestinationsRestServlet(hs).register(http_server)
+ RoomMessagesRestServlet(hs).register(http_server)
+ RoomTimestampToEventRestServlet(hs).register(http_server)
+ UserByExternalId(hs).register(http_server)
+ UserByThreePid(hs).register(http_server)
- # Some servlets only get registered for the main process.
- if hs.config.worker.worker_app is None:
- SendServerNoticeServlet(hs).register(http_server)
- BackgroundUpdateEnabledRestServlet(hs).register(http_server)
- BackgroundUpdateRestServlet(hs).register(http_server)
- BackgroundUpdateStartJobRestServlet(hs).register(http_server)
+ DeviceRestServlet(hs).register(http_server)
+ DevicesRestServlet(hs).register(http_server)
+ DeleteDevicesRestServlet(hs).register(http_server)
+ SendServerNoticeServlet(hs).register(http_server)
+ BackgroundUpdateEnabledRestServlet(hs).register(http_server)
+ BackgroundUpdateRestServlet(hs).register(http_server)
+ BackgroundUpdateStartJobRestServlet(hs).register(http_server)
def register_servlets_for_client_rest_resource(
@@ -286,9 +296,11 @@ def register_servlets_for_client_rest_resource(
"""Register only the servlets which need to be exposed on /_matrix/client/xxx"""
WhoisRestServlet(hs).register(http_server)
PurgeHistoryStatusRestServlet(hs).register(http_server)
- DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)
- ResetPasswordRestServlet(hs).register(http_server)
+ # The following resources can only be run on the main process.
+ if hs.config.worker.worker_app is None:
+ DeactivateAccountRestServlet(hs).register(http_server)
+ ResetPasswordRestServlet(hs).register(http_server)
SearchUsersRestServlet(hs).register(http_server)
UserRegisterServlet(hs).register(http_server)
AccountValidityRenewServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index 399b205aaf..b467a61dfb 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -19,7 +19,7 @@ from typing import Iterable, Pattern
from synapse.api.auth import Auth
from synapse.api.errors import AuthError
from synapse.http.site import SynapseRequest
-from synapse.types import UserID
+from synapse.types import Requester
def admin_patterns(path_regex: str, version: str = "v1") -> Iterable[Pattern]:
@@ -48,19 +48,19 @@ async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None
AuthError if the requester is not a server admin
"""
requester = await auth.get_user_by_req(request)
- await assert_user_is_admin(auth, requester.user)
+ await assert_user_is_admin(auth, requester)
-async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None:
+async def assert_user_is_admin(auth: Auth, requester: Requester) -> None:
"""Verify that the given user is an admin user
Args:
auth: Auth singleton
- user_id: user to check
+ requester: The user making the request, according to the access token.
Raises:
AuthError if the user is not a server admin
"""
- is_admin = await auth.is_server_admin(user_id)
+ is_admin = await auth.is_server_admin(requester)
if not is_admin:
raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")
diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index d934880102..3b2f2d9abb 100644
--- a/synapse/rest/admin/devices.py
+++ b/synapse/rest/admin/devices.py
@@ -16,6 +16,7 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import NotFoundError, SynapseError
+from synapse.handlers.device import DeviceHandler
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
@@ -43,7 +44,9 @@ class DeviceRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine
@@ -112,7 +115,9 @@ class DevicesRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine
@@ -143,7 +148,9 @@ class DeleteDevicesRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index 19d4a008e8..73470f09ae 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -54,7 +54,7 @@ class QuarantineMediaInRoom(RestServlet):
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
logging.info("Quarantining room: %s", room_id)
@@ -81,7 +81,7 @@ class QuarantineMediaByUser(RestServlet):
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
logging.info("Quarantining media by user: %s", user_id)
@@ -110,7 +110,7 @@ class QuarantineMediaByID(RestServlet):
self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
logging.info("Quarantining media by ID: %s/%s", server_name, media_id)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 9d953d58de..747e6fda83 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -35,6 +35,7 @@ from synapse.rest.admin._base import (
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.storage.state import StateFilter
+from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, RoomID, UserID, create_requester
from synapse.util import json_decoder
@@ -75,7 +76,7 @@ class RoomRestV2Servlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
- await assert_user_is_admin(self._auth, requester.user)
+ await assert_user_is_admin(self._auth, requester)
content = parse_json_object_from_request(request)
@@ -303,6 +304,7 @@ class RoomRestServlet(RestServlet):
members = await self.store.get_users_in_room(room_id)
ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
+ ret["forgotten"] = await self.store.is_locally_forgotten_room(room_id)
return HTTPStatus.OK, ret
@@ -326,7 +328,7 @@ class RoomRestServlet(RestServlet):
pagination_handler: "PaginationHandler",
) -> Tuple[int, JsonDict]:
requester = await auth.get_user_by_req(request)
- await assert_user_is_admin(auth, requester.user)
+ await assert_user_is_admin(auth, requester)
content = parse_json_object_from_request(request)
@@ -460,7 +462,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
assert request.args is not None
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
content = parse_json_object_from_request(request)
@@ -550,7 +552,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
content = parse_json_object_from_request(request, allow_empty_body=True)
room_id, _ = await self.resolve_room_id(room_identifier)
@@ -741,7 +743,7 @@ class RoomEventContextServlet(RestServlet):
self, request: SynapseRequest, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
limit = parse_integer(request, "limit", default=10)
@@ -833,7 +835,7 @@ class BlockRoomRestServlet(RestServlet):
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
- await assert_user_is_admin(self._auth, requester.user)
+ await assert_user_is_admin(self._auth, requester)
content = parse_json_object_from_request(request)
@@ -857,3 +859,106 @@ class BlockRoomRestServlet(RestServlet):
await self._store.unblock_room(room_id)
return HTTPStatus.OK, {"block": block}
+
+
+class RoomMessagesRestServlet(RestServlet):
+ """
+ Get messages list of a room.
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/messages$")
+
+ def __init__(self, hs: "HomeServer"):
+ self._hs = hs
+ self._clock = hs.get_clock()
+ self._pagination_handler = hs.get_pagination_handler()
+ self._auth = hs.get_auth()
+ self._store = hs.get_datastores().main
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self._auth.get_user_by_req(request)
+ await assert_user_is_admin(self._auth, requester)
+
+ pagination_config = await PaginationConfig.from_request(
+ self._store, request, default_limit=10
+ )
+ # Twisted will have processed the args by now.
+ assert request.args is not None
+ as_client_event = b"raw" not in request.args
+ filter_str = parse_string(request, "filter", encoding="utf-8")
+ if filter_str:
+ filter_json = urlparse.unquote(filter_str)
+ event_filter: Optional[Filter] = Filter(
+ self._hs, json_decoder.decode(filter_json)
+ )
+ if (
+ event_filter
+ and event_filter.filter_json.get("event_format", "client")
+ == "federation"
+ ):
+ as_client_event = False
+ else:
+ event_filter = None
+
+ msgs = await self._pagination_handler.get_messages(
+ room_id=room_id,
+ requester=requester,
+ pagin_config=pagination_config,
+ as_client_event=as_client_event,
+ event_filter=event_filter,
+ use_admin_priviledge=True,
+ )
+
+ return HTTPStatus.OK, msgs
+
+
+class RoomTimestampToEventRestServlet(RestServlet):
+ """
+ API endpoint to fetch the `event_id` of the closest event to the given
+ timestamp (`ts` query parameter) in the given direction (`dir` query
+ parameter).
+
+ Useful for cases like jump to date so you can start paginating messages from
+ a given date in the archive.
+
+ `ts` is a timestamp in milliseconds where we will find the closest event in
+ the given direction.
+
+ `dir` can be `f` or `b` to indicate forwards and backwards in time from the
+ given timestamp.
+
+ GET /_synapse/admin/v1/rooms/<roomID>/timestamp_to_event?ts=<timestamp>&dir=<direction>
+ {
+ "event_id": ...
+ }
+ """
+
+ PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/timestamp_to_event$")
+
+ def __init__(self, hs: "HomeServer"):
+ self._auth = hs.get_auth()
+ self._store = hs.get_datastores().main
+ self._timestamp_lookup_handler = hs.get_timestamp_lookup_handler()
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self._auth.get_user_by_req(request)
+ await assert_user_is_admin(self._auth, requester)
+
+ timestamp = parse_integer(request, "ts", required=True)
+ direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
+
+ (
+ event_id,
+ origin_server_ts,
+ ) = await self._timestamp_lookup_handler.get_event_for_timestamp(
+ requester, room_id, timestamp, direction
+ )
+
+ return HTTPStatus.OK, {
+ "event_id": event_id,
+ "origin_server_ts": origin_server_ts,
+ }
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index ba2f7fa6d8..6e0c44be2a 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -69,6 +69,7 @@ class UsersRestServletV2(RestServlet):
self.store = hs.get_datastores().main
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
+ self._msc3866_enabled = hs.config.experimental.msc3866.enabled
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
@@ -95,6 +96,13 @@ class UsersRestServletV2(RestServlet):
guests = parse_boolean(request, "guests", default=True)
deactivated = parse_boolean(request, "deactivated", default=False)
+ # If support for MSC3866 is not enabled, apply no filtering based on the
+ # `approved` column.
+ if self._msc3866_enabled:
+ approved = parse_boolean(request, "approved", default=True)
+ else:
+ approved = True
+
order_by = parse_string(
request,
"order_by",
@@ -115,8 +123,22 @@ class UsersRestServletV2(RestServlet):
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b"))
users, total = await self.store.get_users_paginate(
- start, limit, user_id, name, guests, deactivated, order_by, direction
+ start,
+ limit,
+ user_id,
+ name,
+ guests,
+ deactivated,
+ order_by,
+ direction,
+ approved,
)
+
+ # If support for MSC3866 is not enabled, don't show the approval flag.
+ if not self._msc3866_enabled:
+ for user in users:
+ del user["approved"]
+
ret = {"users": users, "total": total}
if (start + limit) < total:
ret["next_token"] = str(start + len(users))
@@ -163,6 +185,7 @@ class UserRestServletV2(RestServlet):
self.deactivate_account_handler = hs.get_deactivate_account_handler()
self.registration_handler = hs.get_registration_handler()
self.pusher_pool = hs.get_pusherpool()
+ self._msc3866_enabled = hs.config.experimental.msc3866.enabled
async def on_GET(
self, request: SynapseRequest, user_id: str
@@ -183,7 +206,7 @@ class UserRestServletV2(RestServlet):
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
target_user = UserID.from_string(user_id)
body = parse_json_object_from_request(request)
@@ -239,6 +262,15 @@ class UserRestServletV2(RestServlet):
HTTPStatus.BAD_REQUEST, "'deactivated' parameter is not of type boolean"
)
+ approved: Optional[bool] = None
+ if "approved" in body and self._msc3866_enabled:
+ approved = body["approved"]
+ if not isinstance(approved, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "'approved' parameter is not of type boolean",
+ )
+
# convert List[Dict[str, str]] into List[Tuple[str, str]]
if external_ids is not None:
new_external_ids = [
@@ -343,6 +375,9 @@ class UserRestServletV2(RestServlet):
if "user_type" in body:
await self.store.set_user_type(target_user, user_type)
+ if approved is not None:
+ await self.store.update_user_approval_status(target_user, approved)
+
user = await self.admin_handler.get_user(target_user)
assert user is not None
@@ -355,6 +390,10 @@ class UserRestServletV2(RestServlet):
if password is not None:
password_hash = await self.auth_handler.hash(password)
+ new_user_approved = True
+ if self._msc3866_enabled and approved is not None:
+ new_user_approved = approved
+
user_id = await self.registration_handler.register_user(
localpart=target_user.localpart,
password_hash=password_hash,
@@ -362,6 +401,7 @@ class UserRestServletV2(RestServlet):
default_display_name=displayname,
user_type=user_type,
by_admin=True,
+ approved=new_user_approved,
)
if threepids is not None:
@@ -375,7 +415,7 @@ class UserRestServletV2(RestServlet):
and self.hs.config.email.email_notif_for_new_users
and medium == "email"
):
- await self.pusher_pool.add_pusher(
+ await self.pusher_pool.add_or_update_pusher(
user_id=user_id,
access_token=None,
kind="email",
@@ -383,7 +423,7 @@ class UserRestServletV2(RestServlet):
app_display_name="Email Notifications",
device_display_name=address,
pushkey=address,
- lang=None, # We don't know a user's language here
+ lang=None,
data={},
)
@@ -550,6 +590,7 @@ class UserRegisterServlet(RestServlet):
user_type=user_type,
default_display_name=displayname,
by_admin=True,
+ approved=True,
)
result = await register._create_registration_details(user_id, body)
@@ -575,10 +616,9 @@ class WhoisRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
- auth_user = requester.user
- if target_user != auth_user:
- await assert_user_is_admin(self.auth, auth_user)
+ if target_user != requester.user:
+ await assert_user_is_admin(self.auth, requester)
if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
@@ -601,7 +641,7 @@ class DeactivateAccountRestServlet(RestServlet):
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
if not self.is_mine(UserID.from_string(target_user_id)):
raise SynapseError(
@@ -693,7 +733,7 @@ class ResetPasswordRestServlet(RestServlet):
This needs user to have administrator access in Synapse.
"""
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
UserID.from_string(target_user_id)
@@ -807,7 +847,7 @@ class UserAdminServlet(RestServlet):
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
auth_user = requester.user
target_user = UserID.from_string(user_id)
@@ -863,8 +903,9 @@ class PushersRestServlet(RestServlet):
@user:server/pushers
Returns:
- pushers: Dictionary containing pushers information.
- total: Number of pushers in dictionary `pushers`.
+ A dictionary with keys:
+ pushers: Dictionary containing pushers information.
+ total: Number of pushers in dictionary `pushers`.
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
@@ -921,7 +962,7 @@ class UserTokenRestServlet(RestServlet):
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- await assert_user_is_admin(self.auth, requester.user)
+ await assert_user_is_admin(self.auth, requester)
auth_user = requester.user
if not self.is_mine_id(user_id):
@@ -1157,3 +1198,55 @@ class AccountDataRestServlet(RestServlet):
"rooms": by_room_data,
},
}
+
+
+class UserByExternalId(RestServlet):
+ """Find a user based on an external ID from an auth provider"""
+
+ PATTERNS = admin_patterns(
+ "/auth_providers/(?P<provider>[^/]*)/users/(?P<external_id>[^/]*)"
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ self._auth = hs.get_auth()
+ self._store = hs.get_datastores().main
+
+ async def on_GET(
+ self,
+ request: SynapseRequest,
+ provider: str,
+ external_id: str,
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self._auth, request)
+
+ user_id = await self._store.get_user_by_external_id(provider, external_id)
+
+ if user_id is None:
+ raise NotFoundError("User not found")
+
+ return HTTPStatus.OK, {"user_id": user_id}
+
+
+class UserByThreePid(RestServlet):
+ """Find a user based on 3PID of a particular medium"""
+
+ PATTERNS = admin_patterns("/threepid/(?P<medium>[^/]*)/users/(?P<address>[^/]*)")
+
+ def __init__(self, hs: "HomeServer"):
+ self._auth = hs.get_auth()
+ self._store = hs.get_datastores().main
+
+ async def on_GET(
+ self,
+ request: SynapseRequest,
+ medium: str,
+ address: str,
+ ) -> Tuple[int, JsonDict]:
+ await assert_requester_is_admin(self._auth, request)
+
+ user_id = await self._store.get_user_id_by_threepid(medium, address)
+
+ if user_id is None:
+ raise NotFoundError("User not found")
+
+ return HTTPStatus.OK, {"user_id": user_id}
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index 50edc6b7d3..44f622bcce 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -15,10 +15,12 @@
# limitations under the License.
import logging
import random
-from http import HTTPStatus
-from typing import TYPE_CHECKING, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
from urllib.parse import urlparse
+from pydantic import StrictBool, StrictStr, constr
+from typing_extensions import Literal
+
from twisted.web.server import Request
from synapse.api.constants import LoginType
@@ -28,18 +30,25 @@ from synapse.api.errors import (
SynapseError,
ThreepidValidationError,
)
-from synapse.config.emailconfig import ThreepidBehaviour
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http.server import HttpServer, finish_request, respond_with_html
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
+ parse_and_validate_json_object_from_request,
parse_json_object_from_request,
parse_string,
)
from synapse.http.site import SynapseRequest
from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
+from synapse.rest.client.models import (
+ AuthenticationData,
+ ClientSecretStr,
+ EmailRequestTokenBody,
+ MsisdnRequestTokenBody,
+)
+from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -64,7 +73,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
self.config = hs.config
self.identity_handler = hs.get_identity_handler()
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.can_verify_email:
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email.email_app_name,
@@ -73,41 +82,24 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.email.local_threepid_handling_disabled_due_to_email_config:
- logger.warning(
- "User password resets have been disabled due to lack of email config"
- )
+ if not self.config.email.can_verify_email:
+ logger.warning(
+ "User password resets have been disabled due to lack of email config"
+ )
raise SynapseError(
400, "Email-based password resets have been disabled on this server"
)
- body = parse_json_object_from_request(request)
-
- assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
-
- # Extract params from body
- client_secret = body["client_secret"]
- assert_valid_client_secret(client_secret)
-
- # Canonicalise the email address. The addresses are all stored canonicalised
- # in the database. This allows the user to reset his password without having to
- # know the exact spelling (eg. upper and lower case) of address in the database.
- # Stored in the database "foo@bar.com"
- # User requests with "FOO@bar.com" would raise a Not Found error
- try:
- email = validate_email(body["email"])
- except ValueError as e:
- raise SynapseError(400, str(e))
- send_attempt = body["send_attempt"]
- next_link = body.get("next_link") # Optional param
+ body = parse_and_validate_json_object_from_request(
+ request, EmailRequestTokenBody
+ )
- if next_link:
+ if body.next_link:
# Raise if the provided next_link value isn't valid
- assert_valid_next_link(self.hs, next_link)
+ assert_valid_next_link(self.hs, body.next_link)
await self.identity_handler.ratelimit_request_token_requests(
- request, "email", email
+ request, "email", body.email
)
# The email will be sent to the stored address.
@@ -115,7 +107,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# an email address which is controlled by the attacker but which, after
# canonicalisation, matches the one in our database.
existing_user_id = await self.hs.get_datastores().main.get_user_id_by_threepid(
- "email", email
+ "email", body.email
)
if existing_user_id is None:
@@ -129,35 +121,20 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- assert self.hs.config.registration.account_threepid_delegate_email
-
- # Have the configured identity server handle the request
- ret = await self.identity_handler.request_email_token(
- self.hs.config.registration.account_threepid_delegate_email,
- email,
- client_secret,
- send_attempt,
- next_link,
- )
- else:
- # Send password reset emails from Synapse
- sid = await self.identity_handler.send_threepid_validation(
- email,
- client_secret,
- send_attempt,
- self.mailer.send_password_reset_mail,
- next_link,
- )
-
- # Wrap the session id in a JSON object
- ret = {"sid": sid}
-
+ # Send password reset emails from Synapse
+ sid = await self.identity_handler.send_threepid_validation(
+ body.email,
+ body.client_secret,
+ body.send_attempt,
+ self.mailer.send_password_reset_mail,
+ body.next_link,
+ )
threepid_send_requests.labels(type="email", reason="password_reset").observe(
- send_attempt
+ body.send_attempt
)
- return 200, ret
+ # Wrap the session id in a JSON object
+ return 200, {"sid": sid}
class PasswordRestServlet(RestServlet):
@@ -172,16 +149,23 @@ class PasswordRestServlet(RestServlet):
self.password_policy_handler = hs.get_password_policy_handler()
self._set_password_handler = hs.get_set_password_handler()
+ class PostBody(RequestBodyModel):
+ auth: Optional[AuthenticationData] = None
+ logout_devices: StrictBool = True
+ if TYPE_CHECKING:
+ # workaround for https://github.com/samuelcolvin/pydantic/issues/156
+ new_password: Optional[StrictStr] = None
+ else:
+ new_password: Optional[constr(max_length=512, strict=True)] = None
+
@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- body = parse_json_object_from_request(request)
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
# we do basic sanity checks here because the auth layer will store these
# in sessions. Pull out the new password provided to us.
- new_password = body.pop("new_password", None)
+ new_password = body.new_password
if new_password is not None:
- if not isinstance(new_password, str) or len(new_password) > 512:
- raise SynapseError(400, "Invalid password")
self.password_policy_handler.validate_password(new_password)
# there are two possibilities here. Either the user does not have an
@@ -201,7 +185,7 @@ class PasswordRestServlet(RestServlet):
params, session_id = await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
- body,
+ body.dict(exclude_unset=True),
"modify your account password",
)
except InteractiveAuthIncompleteError as e:
@@ -224,7 +208,7 @@ class PasswordRestServlet(RestServlet):
result, params, session_id = await self.auth_handler.check_ui_auth(
[[LoginType.EMAIL_IDENTITY]],
request,
- body,
+ body.dict(exclude_unset=True),
"modify your account password",
)
except InteractiveAuthIncompleteError as e:
@@ -299,37 +283,33 @@ class DeactivateAccountRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self._deactivate_account_handler = hs.get_deactivate_account_handler()
+ class PostBody(RequestBodyModel):
+ auth: Optional[AuthenticationData] = None
+ id_server: Optional[StrictStr] = None
+ # Not specced, see https://github.com/matrix-org/matrix-spec/issues/297
+ erase: StrictBool = False
+
@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- body = parse_json_object_from_request(request)
- erase = body.get("erase", False)
- if not isinstance(erase, bool):
- raise SynapseError(
- HTTPStatus.BAD_REQUEST,
- "Param 'erase' must be a boolean, if given",
- Codes.BAD_JSON,
- )
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
requester = await self.auth.get_user_by_req(request)
# allow ASes to deactivate their own users
if requester.app_service:
await self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(), erase, requester
+ requester.user.to_string(), body.erase, requester
)
return 200, {}
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
- body,
+ body.dict(exclude_unset=True),
"deactivate your account",
)
result = await self._deactivate_account_handler.deactivate_account(
- requester.user.to_string(),
- erase,
- requester,
- id_server=body.get("id_server"),
+ requester.user.to_string(), body.erase, requester, id_server=body.id_server
)
if result:
id_server_unbind_result = "success"
@@ -349,7 +329,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.store = self.hs.get_datastores().main
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.can_verify_email:
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email.email_app_name,
@@ -358,34 +338,20 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.email.local_threepid_handling_disabled_due_to_email_config:
- logger.warning(
- "Adding emails have been disabled due to lack of an email config"
- )
+ if not self.config.email.can_verify_email:
+ logger.warning(
+ "Adding emails have been disabled due to lack of an email config"
+ )
raise SynapseError(
- 400, "Adding an email to your account is disabled on this server"
+ 400,
+ "Adding an email to your account is disabled on this server",
)
- body = parse_json_object_from_request(request)
- assert_params_in_dict(body, ["client_secret", "email", "send_attempt"])
- client_secret = body["client_secret"]
- assert_valid_client_secret(client_secret)
-
- # Canonicalise the email address. The addresses are all stored canonicalised
- # in the database.
- # This ensures that the validation email is sent to the canonicalised address
- # as it will later be entered into the database.
- # Otherwise the email will be sent to "FOO@bar.com" and stored as
- # "foo@bar.com" in database.
- try:
- email = validate_email(body["email"])
- except ValueError as e:
- raise SynapseError(400, str(e))
- send_attempt = body["send_attempt"]
- next_link = body.get("next_link") # Optional param
+ body = parse_and_validate_json_object_from_request(
+ request, EmailRequestTokenBody
+ )
- if not await check_3pid_allowed(self.hs, "email", email):
+ if not await check_3pid_allowed(self.hs, "email", body.email):
raise SynapseError(
403,
"Your email domain is not authorized on this server",
@@ -393,14 +359,14 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
)
await self.identity_handler.ratelimit_request_token_requests(
- request, "email", email
+ request, "email", body.email
)
- if next_link:
+ if body.next_link:
# Raise if the provided next_link value isn't valid
- assert_valid_next_link(self.hs, next_link)
+ assert_valid_next_link(self.hs, body.next_link)
- existing_user_id = await self.store.get_user_id_by_threepid("email", email)
+ existing_user_id = await self.store.get_user_id_by_threepid("email", body.email)
if existing_user_id is not None:
if self.config.server.request_token_inhibit_3pid_errors:
@@ -413,35 +379,21 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- assert self.hs.config.registration.account_threepid_delegate_email
-
- # Have the configured identity server handle the request
- ret = await self.identity_handler.request_email_token(
- self.hs.config.registration.account_threepid_delegate_email,
- email,
- client_secret,
- send_attempt,
- next_link,
- )
- else:
- # Send threepid validation emails from Synapse
- sid = await self.identity_handler.send_threepid_validation(
- email,
- client_secret,
- send_attempt,
- self.mailer.send_add_threepid_mail,
- next_link,
- )
-
- # Wrap the session id in a JSON object
- ret = {"sid": sid}
+ # Send threepid validation emails from Synapse
+ sid = await self.identity_handler.send_threepid_validation(
+ body.email,
+ body.client_secret,
+ body.send_attempt,
+ self.mailer.send_add_threepid_mail,
+ body.next_link,
+ )
threepid_send_requests.labels(type="email", reason="add_threepid").observe(
- send_attempt
+ body.send_attempt
)
- return 200, ret
+ # Wrap the session id in a JSON object
+ return 200, {"sid": sid}
class MsisdnThreepidRequestTokenRestServlet(RestServlet):
@@ -454,23 +406,16 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- body = parse_json_object_from_request(request)
- assert_params_in_dict(
- body, ["client_secret", "country", "phone_number", "send_attempt"]
+ body = parse_and_validate_json_object_from_request(
+ request, MsisdnRequestTokenBody
)
- client_secret = body["client_secret"]
- assert_valid_client_secret(client_secret)
-
- country = body["country"]
- phone_number = body["phone_number"]
- send_attempt = body["send_attempt"]
- next_link = body.get("next_link") # Optional param
-
- msisdn = phone_number_to_msisdn(country, phone_number)
+ msisdn = phone_number_to_msisdn(body.country, body.phone_number)
if not await check_3pid_allowed(self.hs, "msisdn", msisdn):
raise SynapseError(
403,
+ # TODO: is this error message accurate? Looks like we've only rejected
+ # this phone number, not necessarily all phone numbers
"Account phone numbers are not authorized on this server",
Codes.THREEPID_DENIED,
)
@@ -479,9 +424,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
request, "msisdn", msisdn
)
- if next_link:
+ if body.next_link:
# Raise if the provided next_link value isn't valid
- assert_valid_next_link(self.hs, next_link)
+ assert_valid_next_link(self.hs, body.next_link)
existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn)
@@ -508,15 +453,15 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
ret = await self.identity_handler.requestMsisdnToken(
self.hs.config.registration.account_threepid_delegate_msisdn,
- country,
- phone_number,
- client_secret,
- send_attempt,
- next_link,
+ body.country,
+ body.phone_number,
+ body.client_secret,
+ body.send_attempt,
+ body.next_link,
)
threepid_send_requests.labels(type="msisdn", reason="add_threepid").observe(
- send_attempt
+ body.send_attempt
)
return 200, ret
@@ -534,24 +479,18 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet):
self.config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.can_verify_email:
self._failure_email_template = (
self.config.email.email_add_threepid_template_failure_html
)
async def on_GET(self, request: Request) -> None:
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.email.local_threepid_handling_disabled_due_to_email_config:
- logger.warning(
- "Adding emails have been disabled due to lack of an email config"
- )
- raise SynapseError(
- 400, "Adding an email to your account is disabled on this server"
+ if not self.config.email.can_verify_email:
+ logger.warning(
+ "Adding emails have been disabled due to lack of an email config"
)
- elif self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
raise SynapseError(
- 400,
- "This homeserver is not validating threepids.",
+ 400, "Adding an email to your account is disabled on this server"
)
sid = parse_string(request, "sid", required=True)
@@ -595,6 +534,11 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
"/add_threepid/msisdn/submit_token$", releases=(), unstable=True
)
+ class PostBody(RequestBodyModel):
+ client_secret: ClientSecretStr
+ sid: StrictStr
+ token: StrictStr
+
def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
@@ -610,16 +554,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
"instead.",
)
- body = parse_json_object_from_request(request)
- assert_params_in_dict(body, ["client_secret", "sid", "token"])
- assert_valid_client_secret(body["client_secret"])
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
# Proxy submit_token request to msisdn threepid delegate
response = await self.identity_handler.proxy_msisdn_submit_token(
self.config.registration.account_threepid_delegate_msisdn,
- body["client_secret"],
- body["sid"],
- body["token"],
+ body.client_secret,
+ body.sid,
+ body.token,
)
return 200, response
@@ -642,6 +584,10 @@ class ThreepidRestServlet(RestServlet):
return 200, {"threepids": threepids}
+ # NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because
+ # the endpoint is deprecated. (If you really want to, you could do this by reusing
+ # ThreePidBindRestServelet.PostBody with an `alias_generator` to handle
+ # `threePidCreds` versus `three_pid_creds`.
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(
@@ -690,6 +636,11 @@ class ThreepidAddRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ class PostBody(RequestBodyModel):
+ auth: Optional[AuthenticationData] = None
+ client_secret: ClientSecretStr
+ sid: StrictStr
+
@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
@@ -699,22 +650,17 @@ class ThreepidAddRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
- body = parse_json_object_from_request(request)
-
- assert_params_in_dict(body, ["client_secret", "sid"])
- sid = body["sid"]
- client_secret = body["client_secret"]
- assert_valid_client_secret(client_secret)
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
- body,
+ body.dict(exclude_unset=True),
"add a third-party identifier to your account",
)
validation_session = await self.identity_handler.validate_threepid_session(
- client_secret, sid
+ body.client_secret, body.sid
)
if validation_session:
await self.auth_handler.add_threepid(
@@ -739,23 +685,20 @@ class ThreepidBindRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
- async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- body = parse_json_object_from_request(request)
+ class PostBody(RequestBodyModel):
+ client_secret: ClientSecretStr
+ id_access_token: StrictStr
+ id_server: StrictStr
+ sid: StrictStr
- assert_params_in_dict(
- body, ["id_server", "sid", "id_access_token", "client_secret"]
- )
- id_server = body["id_server"]
- sid = body["sid"]
- id_access_token = body["id_access_token"]
- client_secret = body["client_secret"]
- assert_valid_client_secret(client_secret)
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
await self.identity_handler.bind_threepid(
- client_secret, sid, user_id, id_server, id_access_token
+ body.client_secret, body.sid, user_id, body.id_server, body.id_access_token
)
return 200, {}
@@ -771,23 +714,27 @@ class ThreepidUnbindRestServlet(RestServlet):
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastores().main
+ class PostBody(RequestBodyModel):
+ address: StrictStr
+ id_server: Optional[StrictStr] = None
+ medium: Literal["email", "msisdn"]
+
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound
"""
requester = await self.auth.get_user_by_req(request)
- body = parse_json_object_from_request(request)
- assert_params_in_dict(body, ["medium", "address"])
-
- medium = body.get("medium")
- address = body.get("address")
- id_server = body.get("id_server")
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
# Attempt to unbind the threepid from an identity server. If id_server is None, try to
# unbind from all identity servers this threepid has been added to in the past
result = await self.identity_handler.try_unbind_threepid(
requester.user.to_string(),
- {"address": address, "medium": medium, "id_server": id_server},
+ {
+ "address": body.address,
+ "medium": body.medium,
+ "id_server": body.id_server,
+ },
)
return 200, {"id_server_unbind_result": "success" if result else "no-support"}
@@ -801,21 +748,25 @@ class ThreepidDeleteRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
+ class PostBody(RequestBodyModel):
+ address: StrictStr
+ id_server: Optional[StrictStr] = None
+ medium: Literal["email", "msisdn"]
+
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
- body = parse_json_object_from_request(request)
- assert_params_in_dict(body, ["medium", "address"])
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
try:
ret = await self.auth_handler.delete_threepid(
- user_id, body["medium"], body["address"], body.get("id_server")
+ user_id, body.medium, body.address, body.id_server
)
except Exception:
# NB. This endpoint should succeed if there is nothing to
@@ -905,17 +856,18 @@ class AccountStatusRestServlet(RestServlet):
self._auth = hs.get_auth()
self._account_handler = hs.get_account_handler()
+ class PostBody(RequestBodyModel):
+ # TODO: we could validate that each user id is an mxid here, and/or parse it
+ # as a UserID
+ user_ids: List[StrictStr]
+
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await self._auth.get_user_by_req(request)
- body = parse_json_object_from_request(request)
- if "user_ids" not in body:
- raise SynapseError(
- 400, "Required parameter 'user_ids' is missing", Codes.MISSING_PARAM
- )
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
statuses, failures = await self._account_handler.get_account_statuses(
- body["user_ids"],
+ body.user_ids,
allow_remote=True,
)
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 4237071c61..e84dde31b1 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -77,6 +77,11 @@ class CapabilitiesRestServlet(RestServlet):
"enabled": True,
}
+ if self.config.experimental.msc3664_enabled:
+ response["capabilities"]["im.nheko.msc3664.related_event_match"] = {
+ "enabled": self.config.experimental.msc3664_enabled,
+ }
+
return HTTPStatus.OK, response
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 6fab102437..69b803f9f8 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -14,18 +14,22 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from pydantic import Extra, StrictStr
from synapse.api import errors
from synapse.api.errors import NotFoundError
+from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
- assert_params_in_dict,
- parse_json_object_from_request,
+ parse_and_validate_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns, interactive_auth_handler
+from synapse.rest.client.models import AuthenticationData
+from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -42,12 +46,26 @@ class DevicesRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
+ self._msc3852_enabled = hs.config.experimental.msc3852_enabled
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
devices = await self.device_handler.get_devices_by_user(
requester.user.to_string()
)
+
+ # If MSC3852 is disabled, then the "last_seen_user_agent" field will be
+ # removed from each device. If it is enabled, then the field name will
+ # be replaced by the unstable identifier.
+ #
+ # When MSC3852 is accepted, this block of code can just be removed to
+ # expose "last_seen_user_agent" to clients.
+ for device in devices:
+ last_seen_user_agent = device["last_seen_user_agent"]
+ del device["last_seen_user_agent"]
+ if self._msc3852_enabled:
+ device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent
+
return 200, {"devices": devices}
@@ -63,30 +81,34 @@ class DeleteDevicesRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.auth_handler = hs.get_auth_handler()
+ class PostBody(RequestBodyModel):
+ auth: Optional[AuthenticationData]
+ devices: List[StrictStr]
+
@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
try:
- body = parse_json_object_from_request(request)
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
- # DELETE
+ # TODO: Can/should we remove this fallback now?
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
- body = {}
+ body = self.PostBody.parse_obj({})
else:
raise e
- assert_params_in_dict(body, ["devices"])
-
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
- body,
+ body.dict(exclude_unset=True),
"remove device(s) from your account",
# Users might call this multiple times in a row while cleaning up
# devices, allow a single UI auth session to be re-used.
@@ -94,7 +116,7 @@ class DeleteDevicesRestServlet(RestServlet):
)
await self.device_handler.delete_devices(
- requester.user.to_string(), body["devices"]
+ requester.user.to_string(), body.devices
)
return 200, {}
@@ -106,8 +128,11 @@ class DeviceRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
self.auth_handler = hs.get_auth_handler()
+ self._msc3852_enabled = hs.config.experimental.msc3852_enabled
async def on_GET(
self, request: SynapseRequest, device_id: str
@@ -118,8 +143,23 @@ class DeviceRestServlet(RestServlet):
)
if device is None:
raise NotFoundError("No device found")
+
+ # If MSC3852 is disabled, then the "last_seen_user_agent" field will be
+ # removed from each device. If it is enabled, then the field name will
+ # be replaced by the unstable identifier.
+ #
+ # When MSC3852 is accepted, this block of code can just be removed to
+ # expose "last_seen_user_agent" to clients.
+ last_seen_user_agent = device["last_seen_user_agent"]
+ del device["last_seen_user_agent"]
+ if self._msc3852_enabled:
+ device["org.matrix.msc3852.last_seen_user_agent"] = last_seen_user_agent
+
return 200, device
+ class DeleteBody(RequestBodyModel):
+ auth: Optional[AuthenticationData]
+
@interactive_auth_handler
async def on_DELETE(
self, request: SynapseRequest, device_id: str
@@ -127,20 +167,21 @@ class DeviceRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
try:
- body = parse_json_object_from_request(request)
+ body = parse_and_validate_json_object_from_request(request, self.DeleteBody)
except errors.SynapseError as e:
if e.errcode == errors.Codes.NOT_JSON:
+ # TODO: can/should we remove this fallback now?
# deal with older clients which didn't pass a JSON dict
# the same as those that pass an empty dict
- body = {}
+ body = self.DeleteBody.parse_obj({})
else:
raise
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
- body,
+ body.dict(exclude_unset=True),
"remove a device from your account",
# Users might call this multiple times in a row while cleaning up
# devices, allow a single UI auth session to be re-used.
@@ -152,18 +193,33 @@ class DeviceRestServlet(RestServlet):
)
return 200, {}
+ class PutBody(RequestBodyModel):
+ display_name: Optional[StrictStr]
+
async def on_PUT(
self, request: SynapseRequest, device_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- body = parse_json_object_from_request(request)
+ body = parse_and_validate_json_object_from_request(request, self.PutBody)
await self.device_handler.update_device(
- requester.user.to_string(), device_id, body
+ requester.user.to_string(), device_id, body.dict()
)
return 200, {}
+class DehydratedDeviceDataModel(RequestBodyModel):
+ """JSON blob describing a dehydrated device to be stored.
+
+ Expects other freeform fields. Use .dict() to access them.
+ """
+
+ class Config:
+ extra = Extra.allow
+
+ algorithm: StrictStr
+
+
class DehydratedDeviceServlet(RestServlet):
"""Retrieve or store a dehydrated device.
@@ -180,7 +236,7 @@ class DehydratedDeviceServlet(RestServlet):
}
}
- PUT /org.matrix.msc2697/dehydrated_device
+ PUT /org.matrix.msc2697.v2/dehydrated_device
Content-Type: application/json
{
@@ -205,7 +261,9 @@ class DehydratedDeviceServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -219,27 +277,18 @@ class DehydratedDeviceServlet(RestServlet):
else:
raise errors.NotFoundError("No dehydrated device available")
+ class PutBody(RequestBodyModel):
+ device_data: DehydratedDeviceDataModel
+ initial_device_display_name: Optional[StrictStr]
+
async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- submission = parse_json_object_from_request(request)
+ submission = parse_and_validate_json_object_from_request(request, self.PutBody)
requester = await self.auth.get_user_by_req(request)
- if "device_data" not in submission:
- raise errors.SynapseError(
- 400,
- "device_data missing",
- errcode=errors.Codes.MISSING_PARAM,
- )
- elif not isinstance(submission["device_data"], dict):
- raise errors.SynapseError(
- 400,
- "device_data must be an object",
- errcode=errors.Codes.INVALID_PARAM,
- )
-
device_id = await self.device_handler.store_dehydrated_device(
requester.user.to_string(),
- submission["device_data"],
- submission.get("initial_device_display_name", None),
+ submission.device_data.dict(),
+ submission.initial_device_display_name,
)
return 200, {"device_id": device_id}
@@ -271,30 +320,22 @@ class ClaimDehydratedDeviceServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
- self.device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.device_handler = handler
+
+ class PostBody(RequestBodyModel):
+ device_id: StrictStr
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- submission = parse_json_object_from_request(request)
-
- if "device_id" not in submission:
- raise errors.SynapseError(
- 400,
- "device_id missing",
- errcode=errors.Codes.MISSING_PARAM,
- )
- elif not isinstance(submission["device_id"], str):
- raise errors.SynapseError(
- 400,
- "device_id must be a string",
- errcode=errors.Codes.INVALID_PARAM,
- )
+ submission = parse_and_validate_json_object_from_request(request, self.PostBody)
result = await self.device_handler.rehydrate_device(
requester.user.to_string(),
self.auth.get_access_token_from_request(request),
- submission["device_id"],
+ submission.device_id,
)
return 200, result
diff --git a/synapse/rest/client/directory.py b/synapse/rest/client/directory.py
index bc1b18c92d..f17b4c8d22 100644
--- a/synapse/rest/client/directory.py
+++ b/synapse/rest/client/directory.py
@@ -13,15 +13,22 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from pydantic import StrictStr
+from typing_extensions import Literal
from twisted.web.server import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.servlet import (
+ RestServlet,
+ parse_and_validate_json_object_from_request,
+)
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
+from synapse.rest.models import RequestBodyModel
from synapse.types import JsonDict, RoomAlias
if TYPE_CHECKING:
@@ -54,6 +61,12 @@ class ClientDirectoryServer(RestServlet):
return 200, res
+ class PutBody(RequestBodyModel):
+ # TODO: get Pydantic to validate that this is a valid room id?
+ room_id: StrictStr
+ # `servers` is unspecced
+ servers: Optional[List[StrictStr]] = None
+
async def on_PUT(
self, request: SynapseRequest, room_alias: str
) -> Tuple[int, JsonDict]:
@@ -61,31 +74,22 @@ class ClientDirectoryServer(RestServlet):
raise SynapseError(400, "Room alias invalid", errcode=Codes.INVALID_PARAM)
room_alias_obj = RoomAlias.from_string(room_alias)
- content = parse_json_object_from_request(request)
- if "room_id" not in content:
- raise SynapseError(
- 400, 'Missing params: ["room_id"]', errcode=Codes.BAD_JSON
- )
+ content = parse_and_validate_json_object_from_request(request, self.PutBody)
logger.debug("Got content: %s", content)
logger.debug("Got room name: %s", room_alias_obj.to_string())
- room_id = content["room_id"]
- servers = content["servers"] if "servers" in content else None
-
- logger.debug("Got room_id: %s", room_id)
- logger.debug("Got servers: %s", servers)
+ logger.debug("Got room_id: %s", content.room_id)
+ logger.debug("Got servers: %s", content.servers)
- # TODO(erikj): Check types.
-
- room = await self.store.get_room(room_id)
+ room = await self.store.get_room(content.room_id)
if room is None:
raise SynapseError(400, "Room does not exist")
requester = await self.auth.get_user_by_req(request)
await self.directory_handler.create_association(
- requester, room_alias_obj, room_id, servers
+ requester, room_alias_obj, content.room_id, content.servers
)
return 200, {}
@@ -137,16 +141,18 @@ class ClientDirectoryListServer(RestServlet):
return 200, {"visibility": "public" if room["is_public"] else "private"}
+ class PutBody(RequestBodyModel):
+ visibility: Literal["public", "private"] = "public"
+
async def on_PUT(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
- content = parse_json_object_from_request(request)
- visibility = content.get("visibility", "public")
+ content = parse_and_validate_json_object_from_request(request, self.PutBody)
await self.directory_handler.edit_published_room_list(
- requester, room_id, visibility
+ requester, room_id, content.visibility
)
return 200, {}
@@ -163,12 +169,14 @@ class ClientAppserviceDirectoryListServer(RestServlet):
self.directory_handler = hs.get_directory_handler()
self.auth = hs.get_auth()
+ class PutBody(RequestBodyModel):
+ visibility: Literal["public", "private"] = "public"
+
async def on_PUT(
self, request: SynapseRequest, network_id: str, room_id: str
) -> Tuple[int, JsonDict]:
- content = parse_json_object_from_request(request)
- visibility = content.get("visibility", "public")
- return await self._edit(request, network_id, room_id, visibility)
+ content = parse_and_validate_json_object_from_request(request, self.PutBody)
+ return await self._edit(request, network_id, room_id, content.visibility)
async def on_DELETE(
self, request: SynapseRequest, network_id: str, room_id: str
@@ -176,7 +184,11 @@ class ClientAppserviceDirectoryListServer(RestServlet):
return await self._edit(request, network_id, room_id, "private")
async def _edit(
- self, request: SynapseRequest, network_id: str, room_id: str, visibility: str
+ self,
+ request: SynapseRequest,
+ network_id: str,
+ room_id: str,
+ visibility: Literal["public", "private"],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
if not requester.app_service:
diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 916f5230f1..782e7d14e8 100644
--- a/synapse/rest/client/events.py
+++ b/synapse/rest/client/events.py
@@ -50,7 +50,9 @@ class EventStreamRestServlet(RestServlet):
raise SynapseError(400, "Guest users must specify room_id param")
room_id = parse_string(request, "room_id")
- pagin_config = await PaginationConfig.from_request(self.store, request)
+ pagin_config = await PaginationConfig.from_request(
+ self.store, request, default_limit=10
+ )
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if b"timeout" in args:
try:
diff --git a/synapse/rest/client/initial_sync.py b/synapse/rest/client/initial_sync.py
index cfadcb8e50..9b1bb8b521 100644
--- a/synapse/rest/client/initial_sync.py
+++ b/synapse/rest/client/initial_sync.py
@@ -39,7 +39,9 @@ class InitialSyncRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
args: Dict[bytes, List[bytes]] = request.args # type: ignore
as_client_event = b"raw" not in args
- pagination_config = await PaginationConfig.from_request(self.store, request)
+ pagination_config = await PaginationConfig.from_request(
+ self.store, request, default_limit=10
+ )
include_archived = parse_boolean(request, "archived", default=False)
content = await self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index e3f454896a..ee038c7192 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -26,10 +26,11 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
+from synapse.logging.opentracing import log_kv, set_tag
+from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
+from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict, StreamToken
-
-from ._base import client_patterns, interactive_auth_handler
+from synapse.util.cancellation import cancellable
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -43,24 +44,48 @@ class KeyUploadServlet(RestServlet):
Content-Type: application/json
{
- "device_keys": {
- "user_id": "<user_id>",
- "device_id": "<device_id>",
- "valid_until_ts": <millisecond_timestamp>,
- "algorithms": [
- "m.olm.curve25519-aes-sha2",
- ]
- "keys": {
- "<algorithm>:<device_id>": "<key_base64>",
+ "device_keys": {
+ "user_id": "<user_id>",
+ "device_id": "<device_id>",
+ "valid_until_ts": <millisecond_timestamp>,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ ]
+ "keys": {
+ "<algorithm>:<device_id>": "<key_base64>",
+ },
+ "signatures:" {
+ "<user_id>" {
+ "<algorithm>:<device_id>": "<signature_base64>"
+ }
+ }
},
- "signatures:" {
- "<user_id>" {
- "<algorithm>:<device_id>": "<signature_base64>"
- } } },
- "one_time_keys": {
- "<algorithm>:<key_id>": "<key_base64>"
- },
+ "fallback_keys": {
+ "<algorithm>:<device_id>": "<key_base64>",
+ "signed_<algorithm>:<device_id>": {
+ "fallback": true,
+ "key": "<key_base64>",
+ "signatures": {
+ "<user_id>": {
+ "<algorithm>:<device_id>": "<key_base64>"
+ }
+ }
+ }
+ }
+ "one_time_keys": {
+ "<algorithm>:<key_id>": "<key_base64>"
+ },
+ }
+
+ response, e.g.:
+
+ {
+ "one_time_key_counts": {
+ "curve25519": 10,
+ "signed_curve25519": 20
+ }
}
+
"""
PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
@@ -71,7 +96,13 @@ class KeyUploadServlet(RestServlet):
self.e2e_keys_handler = hs.get_e2e_keys_handler()
self.device_handler = hs.get_device_handler()
- @trace_with_opname("upload_keys")
+ if hs.config.worker.worker_app is None:
+ # if main process
+ self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
+ else:
+ # then a worker
+ self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
+
async def on_POST(
self, request: SynapseRequest, device_id: Optional[str]
) -> Tuple[int, JsonDict]:
@@ -110,8 +141,8 @@ class KeyUploadServlet(RestServlet):
400, "To upload keys, you must pass device_id when authenticating"
)
- result = await self.e2e_keys_handler.upload_keys_for_user(
- user_id, device_id, body
+ result = await self.key_uploader(
+ user_id=user_id, device_id=device_id, keys=body
)
return 200, result
@@ -157,6 +188,7 @@ class KeyQueryServlet(RestServlet):
self.auth = hs.get_auth()
self.e2e_keys_handler = hs.get_e2e_keys_handler()
+ @cancellable
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
@@ -200,6 +232,7 @@ class KeyChangesServlet(RestServlet):
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastores().main
+ @cancellable
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 0437c87d8d..8adced41e5 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -28,7 +28,14 @@ from typing import (
from typing_extensions import TypedDict
-from synapse.api.errors import Codes, InvalidClientTokenError, LoginError, SynapseError
+from synapse.api.constants import ApprovalNoticeMedium
+from synapse.api.errors import (
+ Codes,
+ InvalidClientTokenError,
+ LoginError,
+ NotApprovedError,
+ SynapseError,
+)
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.appservice import ApplicationService
@@ -55,11 +62,11 @@ logger = logging.getLogger(__name__)
class LoginResponse(TypedDict, total=False):
user_id: str
- access_token: str
+ access_token: Optional[str]
home_server: str
expires_in_ms: Optional[int]
refresh_token: Optional[str]
- device_id: str
+ device_id: Optional[str]
well_known: Optional[Dict[str, Any]]
@@ -92,6 +99,12 @@ class LoginRestServlet(RestServlet):
hs.config.registration.refreshable_access_token_lifetime is not None
)
+ # Whether we need to check if the user has been approved or not.
+ self._require_approval = (
+ hs.config.experimental.msc3866.enabled
+ and hs.config.experimental.msc3866.require_approval_for_new_accounts
+ )
+
self.auth = hs.get_auth()
self.clock = hs.get_clock()
@@ -220,6 +233,14 @@ class LoginRestServlet(RestServlet):
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
+ if self._require_approval:
+ approved = await self.auth_handler.is_user_approved(result["user_id"])
+ if not approved:
+ raise NotApprovedError(
+ msg="This account is pending approval by a server administrator.",
+ approval_notice_medium=ApprovalNoticeMedium.NONE,
+ )
+
well_known_data = self._well_known_builder.get_well_known()
if well_known_data:
result["well_known"] = well_known_data
@@ -329,7 +350,7 @@ class LoginRestServlet(RestServlet):
auth_provider_session_id: The session ID got during login from the SSO IdP.
Returns:
- result: Dictionary of account information after successful login.
+ Dictionary of account information after successful login.
"""
# Before we actually log them in we check if they've already logged in
@@ -356,6 +377,16 @@ class LoginRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM,
)
+ if self._require_approval:
+ approved = await self.auth_handler.is_user_approved(user_id)
+ if not approved:
+ # If the user isn't approved (and needs to be) we won't allow them to
+ # actually log in, so we don't want to create a device/access token.
+ return LoginResponse(
+ user_id=user_id,
+ home_server=self.hs.hostname,
+ )
+
initial_display_name = login_submission.get("initial_device_display_name")
(
device_id,
@@ -405,8 +436,7 @@ class LoginRestServlet(RestServlet):
The body of the JSON response.
"""
token = login_submission["token"]
- auth_handler = self.auth_handler
- res = await auth_handler.validate_short_term_login_token(token)
+ res = await self.auth_handler.consume_login_token(token)
return await self._complete_login(
res.user_id,
@@ -506,7 +536,7 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
class RefreshTokenServlet(RestServlet):
- PATTERNS = (re.compile("^/_matrix/client/v1/refresh$"),)
+ PATTERNS = client_patterns("/refresh$")
def __init__(self, hs: "HomeServer"):
self._auth_handler = hs.get_auth_handler()
diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py
new file mode 100644
index 0000000000..43ea21d5e6
--- /dev/null
+++ b/synapse/rest/client/login_token_request.py
@@ -0,0 +1,95 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING, Tuple
+
+from synapse.http.server import HttpServer
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.http.site import SynapseRequest
+from synapse.rest.client._base import client_patterns, interactive_auth_handler
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class LoginTokenRequestServlet(RestServlet):
+ """
+ Get a token that can be used with `m.login.token` to log in a second device.
+
+ Request:
+
+ POST /login/token HTTP/1.1
+ Content-Type: application/json
+
+ {}
+
+ Response:
+
+ HTTP/1.1 200 OK
+ {
+ "login_token": "ABDEFGH",
+ "expires_in": 3600,
+ }
+ """
+
+ PATTERNS = client_patterns(
+ "/org.matrix.msc3882/login/token$", releases=[], v1=False, unstable=True
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastores().main
+ self.clock = hs.get_clock()
+ self.server_name = hs.config.server.server_name
+ self.auth_handler = hs.get_auth_handler()
+ self.token_timeout = hs.config.experimental.msc3882_token_timeout
+ self.ui_auth = hs.config.experimental.msc3882_ui_auth
+
+ @interactive_auth_handler
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ body = parse_json_object_from_request(request)
+
+ if self.ui_auth:
+ await self.auth_handler.validate_user_via_ui_auth(
+ requester,
+ request,
+ body,
+ "issue a new access token for your account",
+ can_skip_ui_auth=False, # Don't allow skipping of UI auth
+ )
+
+ login_token = await self.auth_handler.create_login_token_for_user_id(
+ user_id=requester.user.to_string(),
+ auth_provider_id="org.matrix.msc3882.login_token_request",
+ duration_ms=self.token_timeout,
+ )
+
+ return (
+ 200,
+ {
+ "login_token": login_token,
+ "expires_in": self.token_timeout // 1000,
+ },
+ )
+
+
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
+ if hs.config.experimental.msc3882_enabled:
+ LoginTokenRequestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/logout.py b/synapse/rest/client/logout.py
index 23dfa4518f..6d34625ad5 100644
--- a/synapse/rest/client/logout.py
+++ b/synapse/rest/client/logout.py
@@ -15,6 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
+from synapse.handlers.device import DeviceHandler
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet
from synapse.http.site import SynapseRequest
@@ -34,7 +35,9 @@ class LogoutRestServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
- self._device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
@@ -59,7 +62,9 @@ class LogoutAllRestServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
- self._device_handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self._device_handler = handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_expired=True)
diff --git a/synapse/rest/client/models.py b/synapse/rest/client/models.py
new file mode 100644
index 0000000000..3d7940b0fc
--- /dev/null
+++ b/synapse/rest/client/models.py
@@ -0,0 +1,87 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import TYPE_CHECKING, Dict, Optional
+
+from pydantic import Extra, StrictInt, StrictStr, constr, validator
+
+from synapse.rest.models import RequestBodyModel
+from synapse.util.threepids import validate_email
+
+
+class AuthenticationData(RequestBodyModel):
+ """
+ Data used during user-interactive authentication.
+
+ (The name "Authentication Data" is taken directly from the spec.)
+
+ Additional keys will be present, depending on the `type` field. Use
+ `.dict(exclude_unset=True)` to access them.
+ """
+
+ class Config:
+ extra = Extra.allow
+
+ session: Optional[StrictStr] = None
+ type: Optional[StrictStr] = None
+
+
+if TYPE_CHECKING:
+ ClientSecretStr = StrictStr
+else:
+ # See also assert_valid_client_secret()
+ ClientSecretStr = constr(
+ regex="[0-9a-zA-Z.=_-]", # noqa: F722
+ min_length=1,
+ max_length=255,
+ strict=True,
+ )
+
+
+class ThreepidRequestTokenBody(RequestBodyModel):
+ client_secret: ClientSecretStr
+ id_server: Optional[StrictStr]
+ id_access_token: Optional[StrictStr]
+ next_link: Optional[StrictStr]
+ send_attempt: StrictInt
+
+ @validator("id_access_token", always=True)
+ def token_required_for_identity_server(
+ cls, token: Optional[str], values: Dict[str, object]
+ ) -> Optional[str]:
+ if values.get("id_server") is not None and token is None:
+ raise ValueError("id_access_token is required if an id_server is supplied.")
+ return token
+
+
+class EmailRequestTokenBody(ThreepidRequestTokenBody):
+ email: StrictStr
+
+ # Canonicalise the email address. The addresses are all stored canonicalised
+ # in the database. This allows the user to reset his password without having to
+ # know the exact spelling (eg. upper and lower case) of address in the database.
+ # Without this, an email stored in the database as "foo@bar.com" would cause
+ # user requests for "FOO@bar.com" to raise a Not Found error.
+ _email_validator = validator("email", allow_reuse=True)(validate_email)
+
+
+if TYPE_CHECKING:
+ ISO3116_1_Alpha_2 = StrictStr
+else:
+ # Per spec: two-letter uppercase ISO-3166-1-alpha-2
+ ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)
+
+
+class MsisdnRequestTokenBody(ThreepidRequestTokenBody):
+ country: ISO3116_1_Alpha_2
+ phone_number: StrictStr
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index 24bc7c9095..61268e3af1 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -58,7 +58,11 @@ class NotificationsServlet(RestServlet):
)
receipts_by_room = await self.store.get_receipts_for_user_with_orderings(
- user_id, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
+ user_id,
+ [
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ],
)
notif_event_ids = [pa.event_id for pa in push_actions]
diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py
index c16d707909..e69fa0829d 100644
--- a/synapse/rest/client/profile.py
+++ b/synapse/rest/client/profile.py
@@ -66,7 +66,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id)
- is_admin = await self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester)
content = parse_json_object_from_request(request)
@@ -123,7 +123,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id)
- is_admin = await self.auth.is_server_admin(requester.user)
+ is_admin = await self.auth.is_server_admin(requester)
content = parse_json_object_from_request(request)
try:
diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py
index 9a1f10f4be..975eef2144 100644
--- a/synapse/rest/client/pusher.py
+++ b/synapse/rest/client/pusher.py
@@ -42,6 +42,7 @@ class PushersRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
+ self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -51,9 +52,16 @@ class PushersRestServlet(RestServlet):
user.to_string()
)
- filtered_pushers = [p.as_dict() for p in pushers]
+ pusher_dicts = [p.as_dict() for p in pushers]
- return 200, {"pushers": filtered_pushers}
+ for pusher in pusher_dicts:
+ if self._msc3881_enabled:
+ pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
+ pusher["org.matrix.msc3881.device_id"] = pusher["device_id"]
+ del pusher["enabled"]
+ del pusher["device_id"]
+
+ return 200, {"pushers": pusher_dicts}
class PushersSetRestServlet(RestServlet):
@@ -65,6 +73,7 @@ class PushersSetRestServlet(RestServlet):
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
+ self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -103,6 +112,10 @@ class PushersSetRestServlet(RestServlet):
if "append" in content:
append = content["append"]
+ enabled = True
+ if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content:
+ enabled = content["org.matrix.msc3881.enabled"]
+
if not append:
await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content["app_id"],
@@ -111,7 +124,7 @@ class PushersSetRestServlet(RestServlet):
)
try:
- await self.pusher_pool.add_pusher(
+ await self.pusher_pool.add_or_update_pusher(
user_id=user.to_string(),
access_token=requester.access_token_id,
kind=content["kind"],
@@ -122,6 +135,8 @@ class PushersSetRestServlet(RestServlet):
lang=content["lang"],
data=content["data"],
profile_tag=content.get("profile_tag", ""),
+ enabled=enabled,
+ device_id=requester.device_id,
)
except PusherConfigException as pce:
raise SynapseError(
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 8896f2df50..852838515c 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -40,9 +40,11 @@ class ReadMarkerRestServlet(RestServlet):
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
- self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ}
- if hs.config.experimental.msc2285_enabled:
- self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE)
+ self._known_receipt_types = {
+ ReceiptTypes.READ,
+ ReceiptTypes.FULLY_READ,
+ ReceiptTypes.READ_PRIVATE,
+ }
async def on_POST(
self, request: SynapseRequest, room_id: str
@@ -81,6 +83,8 @@ class ReadMarkerRestServlet(RestServlet):
receipt_type,
user_id=requester.user.to_string(),
event_id=event_id,
+ # Setting the thread ID is not possible with the /read_markers endpoint.
+ thread_id=None,
)
return 200, {}
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 409bfd43c1..18a282b22c 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -15,8 +15,8 @@
import logging
from typing import TYPE_CHECKING, Tuple
-from synapse.api.constants import ReceiptTypes
-from synapse.api.errors import SynapseError
+from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes
+from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
@@ -43,12 +43,13 @@ class ReceiptRestServlet(RestServlet):
self.receipts_handler = hs.get_receipts_handler()
self.read_marker_handler = hs.get_read_marker_handler()
self.presence_handler = hs.get_presence_handler()
+ self._main_store = hs.get_datastores().main
- self._known_receipt_types = {ReceiptTypes.READ}
- if hs.config.experimental.msc2285_enabled:
- self._known_receipt_types.update(
- (ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
- )
+ self._known_receipt_types = {
+ ReceiptTypes.READ,
+ ReceiptTypes.READ_PRIVATE,
+ ReceiptTypes.FULLY_READ,
+ }
async def on_POST(
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
@@ -61,7 +62,33 @@ class ReceiptRestServlet(RestServlet):
f"Receipt type must be {', '.join(self._known_receipt_types)}",
)
- parse_json_object_from_request(request, allow_empty_body=False)
+ body = parse_json_object_from_request(request)
+
+ # Pull the thread ID, if one exists.
+ thread_id = None
+ if "thread_id" in body:
+ thread_id = body.get("thread_id")
+ if not thread_id or not isinstance(thread_id, str):
+ raise SynapseError(
+ 400,
+ "thread_id field must be a non-empty string",
+ Codes.INVALID_PARAM,
+ )
+
+ if receipt_type == ReceiptTypes.FULLY_READ:
+ raise SynapseError(
+ 400,
+ f"thread_id is not compatible with {ReceiptTypes.FULLY_READ} receipts.",
+ Codes.INVALID_PARAM,
+ )
+
+ # Ensure the event ID roughly correlates to the thread ID.
+ if not await self._is_event_in_thread(event_id, thread_id):
+ raise SynapseError(
+ 400,
+ f"event_id {event_id} is not related to thread {thread_id}",
+ Codes.INVALID_PARAM,
+ )
await self.presence_handler.bump_presence_active_time(requester.user)
@@ -77,10 +104,51 @@ class ReceiptRestServlet(RestServlet):
receipt_type,
user_id=requester.user.to_string(),
event_id=event_id,
+ thread_id=thread_id,
)
return 200, {}
+ async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool:
+ """
+ The event must be related to the thread ID (in a vague sense) to ensure
+ clients aren't sending bogus receipts.
+
+ A thread ID is considered valid for a given event E if:
+
+ 1. E has a thread relation which matches the thread ID;
+ 2. E has another event which has a thread relation to E matching the
+ thread ID; or
+ 3. E is recursively related (via any rel_type) to an event which
+ satisfies 1 or 2.
+
+ Given the following DAG:
+
+ A <---[m.thread]-- B <--[m.annotation]-- C
+ ^
+ |--[m.reference]-- D <--[m.annotation]-- E
+
+ It is valid to send a receipt for thread A on A, B, C, D, or E.
+
+ It is valid to send a receipt for the main timeline on A, D, and E.
+
+ Args:
+ event_id: The event ID to check.
+ thread_id: The thread ID the event is potentially part of.
+
+ Returns:
+ True if the event belongs to the given thread, otherwise False.
+ """
+
+ # If the receipt is on the main timeline, it is enough to check whether
+ # the event is directly related to a thread.
+ if thread_id == MAIN_TIMELINE:
+ return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id)
+
+ # Otherwise, check if the event is directly part of a thread, or is the
+ # root message (or related to the root message) of a thread.
+ return thread_id == await self._main_store.get_thread_id_for_receipts(event_id)
+
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReceiptRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index 956c45e60a..de810ae3ec 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -21,17 +21,21 @@ from twisted.web.server import Request
import synapse
import synapse.api.auth
import synapse.types
-from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType
+from synapse.api.constants import (
+ APP_SERVICE_REGISTRATION_TYPE,
+ ApprovalNoticeMedium,
+ LoginType,
+)
from synapse.api.errors import (
Codes,
InteractiveAuthIncompleteError,
+ NotApprovedError,
SynapseError,
ThreepidValidationError,
UnrecognizedRequestError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config import ConfigError
-from synapse.config.emailconfig import ThreepidBehaviour
from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import FederationRatelimitSettings
from synapse.config.server import is_threepid_reserved
@@ -74,7 +78,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.config = hs.config
- if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.hs.config.email.can_verify_email:
self.mailer = Mailer(
hs=self.hs,
app_name=self.config.email.email_app_name,
@@ -83,13 +87,10 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
- if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if (
- self.hs.config.email.local_threepid_handling_disabled_due_to_email_config
- ):
- logger.warning(
- "Email registration has been disabled due to lack of email config"
- )
+ if not self.hs.config.email.can_verify_email:
+ logger.warning(
+ "Email registration has been disabled due to lack of email config"
+ )
raise SynapseError(
400, "Email-based registration has been disabled on this server"
)
@@ -138,35 +139,21 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
- assert self.hs.config.registration.account_threepid_delegate_email
-
- # Have the configured identity server handle the request
- ret = await self.identity_handler.request_email_token(
- self.hs.config.registration.account_threepid_delegate_email,
- email,
- client_secret,
- send_attempt,
- next_link,
- )
- else:
- # Send registration emails from Synapse,
- # wrapping the session id in a JSON object.
- ret = {
- "sid": await self.identity_handler.send_threepid_validation(
- email,
- client_secret,
- send_attempt,
- self.mailer.send_registration_mail,
- next_link,
- )
- }
+ # Send registration emails from Synapse
+ sid = await self.identity_handler.send_threepid_validation(
+ email,
+ client_secret,
+ send_attempt,
+ self.mailer.send_registration_mail,
+ next_link,
+ )
threepid_send_requests.labels(type="email", reason="register").observe(
send_attempt
)
- return 200, ret
+ # Wrap the session id in a JSON object
+ return 200, {"sid": sid}
class MsisdnRegisterRequestTokenRestServlet(RestServlet):
@@ -260,7 +247,7 @@ class RegistrationSubmitTokenServlet(RestServlet):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+ if self.config.email.can_verify_email:
self._failure_email_template = (
self.config.email.email_registration_template_failure_html
)
@@ -270,11 +257,10 @@ class RegistrationSubmitTokenServlet(RestServlet):
raise SynapseError(
400, "This medium is currently not supported for registration"
)
- if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF:
- if self.config.email.local_threepid_handling_disabled_due_to_email_config:
- logger.warning(
- "User registration via email has been disabled due to lack of email config"
- )
+ if not self.config.email.can_verify_email:
+ logger.warning(
+ "User registration via email has been disabled due to lack of email config"
+ )
raise SynapseError(
400, "Email-based registration is disabled on this server"
)
@@ -433,6 +419,11 @@ class RegisterRestServlet(RestServlet):
hs.config.registration.inhibit_user_in_use_error
)
+ self._require_approval = (
+ hs.config.experimental.msc3866.enabled
+ and hs.config.experimental.msc3866.require_approval_for_new_accounts
+ )
+
self._registration_flows = _calculate_registration_flows(
hs.config, self.auth_handler
)
@@ -484,9 +475,6 @@ class RegisterRestServlet(RestServlet):
"Appservice token must be provided when using a type of m.login.application_service",
)
- # Verify the AS
- self.auth.get_appservice_by_req(request)
-
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one.
@@ -756,6 +744,12 @@ class RegisterRestServlet(RestServlet):
access_token=return_dict.get("access_token"),
)
+ if self._require_approval:
+ raise NotApprovedError(
+ msg="This account needs to be approved by an administrator before it can be used.",
+ approval_notice_medium=ApprovalNoticeMedium.NONE,
+ )
+
return 200, return_dict
async def _do_appservice_registration(
@@ -800,7 +794,9 @@ class RegisterRestServlet(RestServlet):
"user_id": user_id,
"home_server": self.hs.hostname,
}
- if not params.get("inhibit_login", False):
+ # We don't want to log the user in if we're going to deny them access because
+ # they need to be approved first.
+ if not params.get("inhibit_login", False) and not self._require_approval:
device_id = params.get("device_id")
initial_display_name = params.get("initial_device_display_name")
(
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index ce97080013..9dd59196d9 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -13,13 +13,17 @@
# limitations under the License.
import logging
+import re
from typing import TYPE_CHECKING, Optional, Tuple
+from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.types import JsonDict, StreamToken
+from synapse.storage.databases.main.relations import ThreadsNextBatch
+from synapse.streams.config import PaginationConfig
+from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -41,9 +45,8 @@ class RelationPaginationServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
- self.store = hs.get_datastores().main
+ self._store = hs.get_datastores().main
self._relations_handler = hs.get_relations_handler()
- self._msc3715_enabled = hs.config.experimental.msc3715_enabled
async def on_GET(
self,
@@ -55,37 +58,63 @@ class RelationPaginationServlet(RestServlet):
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ pagination_config = await PaginationConfig.from_request(
+ self._store, request, default_limit=5, default_dir="b"
+ )
+
+ # The unstable version of this API returns an extra field for client
+ # compatibility, see https://github.com/matrix-org/synapse/issues/12930.
+ assert request.path is not None
+ include_original_event = request.path.startswith(b"/_matrix/client/unstable/")
+
+ # Return the relations
+ result = await self._relations_handler.get_relations(
+ requester=requester,
+ event_id=parent_id,
+ room_id=room_id,
+ pagin_config=pagination_config,
+ include_original_event=include_original_event,
+ relation_type=relation_type,
+ event_type=event_type,
+ )
+
+ return 200, result
+
+
+class ThreadsServlet(RestServlet):
+ PATTERNS = (re.compile("^/_matrix/client/v1/rooms/(?P<room_id>[^/]*)/threads"),)
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastores().main
+ self._relations_handler = hs.get_relations_handler()
+
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+
limit = parse_integer(request, "limit", default=5)
- if self._msc3715_enabled:
- direction = parse_string(
- request,
- "org.matrix.msc3715.dir",
- default="b",
- allowed_values=["f", "b"],
- )
- else:
- direction = "b"
from_token_str = parse_string(request, "from")
- to_token_str = parse_string(request, "to")
+ include = parse_string(
+ request,
+ "include",
+ default=ThreadsListInclude.all.value,
+ allowed_values=[v.value for v in ThreadsListInclude],
+ )
# Return the relations
from_token = None
if from_token_str:
- from_token = await StreamToken.from_string(self.store, from_token_str)
- to_token = None
- if to_token_str:
- to_token = await StreamToken.from_string(self.store, to_token_str)
+ from_token = ThreadsNextBatch.from_string(from_token_str)
- result = await self._relations_handler.get_relations(
+ result = await self._relations_handler.get_threads(
requester=requester,
- event_id=parent_id,
room_id=room_id,
- relation_type=relation_type,
- event_type=event_type,
+ include=ThreadsListInclude(include),
limit=limit,
- direction=direction,
from_token=from_token,
- to_token=to_token,
)
return 200, result
@@ -93,3 +122,4 @@ class RelationPaginationServlet(RestServlet):
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationPaginationServlet(hs).register(http_server)
+ ThreadsServlet(hs).register(http_server)
diff --git a/synapse/rest/client/rendezvous.py b/synapse/rest/client/rendezvous.py
new file mode 100644
index 0000000000..89176b1ffa
--- /dev/null
+++ b/synapse/rest/client/rendezvous.py
@@ -0,0 +1,74 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from http.client import TEMPORARY_REDIRECT
+from typing import TYPE_CHECKING, Optional
+
+from synapse.http.server import HttpServer, respond_with_redirect
+from synapse.http.servlet import RestServlet
+from synapse.http.site import SynapseRequest
+from synapse.rest.client._base import client_patterns
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class RendezvousServlet(RestServlet):
+ """
+ This is a placeholder implementation of [MSC3886](https://github.com/matrix-org/matrix-spec-proposals/pull/3886)
+ simple client rendezvous capability that is used by the "Sign in with QR" functionality.
+
+ This implementation only serves as a 307 redirect to a configured server rather than being a full implementation.
+
+ A module that implements the full functionality is available at: https://pypi.org/project/matrix-http-rendezvous-synapse/.
+
+ Request:
+
+ POST /rendezvous HTTP/1.1
+ Content-Type: ...
+
+ ...
+
+ Response:
+
+ HTTP/1.1 307
+ Location: <configured endpoint>
+ """
+
+ PATTERNS = client_patterns(
+ "/org.matrix.msc3886/rendezvous$", releases=[], v1=False, unstable=True
+ )
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ redirection_target: Optional[str] = hs.config.experimental.msc3886_endpoint
+ assert (
+ redirection_target is not None
+ ), "Servlet is only registered if there is a redirection target"
+ self.endpoint = redirection_target.encode("utf-8")
+
+ async def on_POST(self, request: SynapseRequest) -> None:
+ respond_with_redirect(
+ request, self.endpoint, statusCode=TEMPORARY_REDIRECT, cors=True
+ )
+
+ # PUT, GET and DELETE are not implemented as they should be fulfilled by the redirect target.
+
+
+def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
+ if hs.config.experimental.msc3886_endpoint is not None:
+ RendezvousServlet(hs).register(http_server)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 2f513164cb..91cb791139 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -16,9 +16,13 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
import re
+from enum import Enum
+from http import HTTPStatus
from typing import TYPE_CHECKING, Awaitable, Dict, List, Optional, Tuple
from urllib import parse as urlparse
+from prometheus_client.core import Histogram
+
from twisted.web.server import Request
from synapse import event_auth
@@ -34,7 +38,7 @@ from synapse.api.errors import (
)
from synapse.api.filtering import Filter
from synapse.events.utils import format_event_for_client_v2
-from synapse.http.server import HttpServer, cancellable
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
ResolveRoomIdMixin,
RestServlet,
@@ -46,13 +50,16 @@ from synapse.http.servlet import (
parse_strings_from_args,
)
from synapse.http.site import SynapseRequest
+from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client._base import client_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
from synapse.util import json_decoder
+from synapse.util.cancellation import cancellable
from synapse.util.stringutils import parse_and_validate_server_name, random_string
if TYPE_CHECKING:
@@ -61,6 +68,70 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+class _RoomSize(Enum):
+ """
+ Enum to differentiate sizes of rooms. This is a pretty good approximation
+ about how hard it will be to get events in the room. We could also look at
+ room "complexity".
+ """
+
+ # This doesn't necessarily mean the room is a DM, just that there is a DM
+ # amount of people there.
+ DM_SIZE = "direct_message_size"
+ SMALL = "small"
+ SUBSTANTIAL = "substantial"
+ LARGE = "large"
+
+ @staticmethod
+ def from_member_count(member_count: int) -> "_RoomSize":
+ if member_count <= 2:
+ return _RoomSize.DM_SIZE
+ elif member_count < 100:
+ return _RoomSize.SMALL
+ elif member_count < 1000:
+ return _RoomSize.SUBSTANTIAL
+ else:
+ return _RoomSize.LARGE
+
+
+# This is an extra metric on top of `synapse_http_server_response_time_seconds`
+# which times the same sort of thing but this one allows us to see values
+# greater than 10s. We use a separate dedicated histogram with its own buckets
+# so that we don't increase the cardinality of the general one because it's
+# multiplied across hundreds of servlets.
+messsages_response_timer = Histogram(
+ "synapse_room_message_list_rest_servlet_response_time_seconds",
+ "sec",
+ # We have a label for room size so we can try to see a more realistic
+ # picture of /messages response time for bigger rooms. We don't want the
+ # tiny rooms that can always respond fast skewing our results when we're trying
+ # to optimize the bigger cases.
+ ["room_size"],
+ buckets=(
+ 0.005,
+ 0.01,
+ 0.025,
+ 0.05,
+ 0.1,
+ 0.25,
+ 0.5,
+ 1.0,
+ 2.5,
+ 5.0,
+ 10.0,
+ 20.0,
+ 30.0,
+ 60.0,
+ 80.0,
+ 100.0,
+ 120.0,
+ 150.0,
+ 180.0,
+ "+Inf",
+ ),
+)
+
+
class TransactionRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
@@ -165,7 +236,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
msg_handler = self.message_handler
data = await msg_handler.get_room_data(
- user_id=requester.user.to_string(),
+ requester=requester,
room_id=room_id,
event_type=event_type,
state_key=state_key,
@@ -198,15 +269,9 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content = parse_json_object_from_request(request)
- event_dict = {
- "type": event_type,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- }
-
- if state_key is not None:
- event_dict["state_key"] = state_key
+ origin_server_ts = None
+ if requester.app_service:
+ origin_server_ts = parse_integer(request, "ts")
try:
if event_type == EventTypes.Member:
@@ -217,8 +282,22 @@ class RoomStateEventRestServlet(TransactionRestServlet):
room_id=room_id,
action=membership,
content=content,
+ origin_server_ts=origin_server_ts,
)
else:
+ event_dict: JsonDict = {
+ "type": event_type,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ }
+
+ if state_key is not None:
+ event_dict["state_key"] = state_key
+
+ if origin_server_ts is not None:
+ event_dict["origin_server_ts"] = origin_server_ts
+
(
event,
_,
@@ -263,10 +342,10 @@ class RoomSendEventRestServlet(TransactionRestServlet):
"sender": requester.user.to_string(),
}
- # Twisted will have processed the args by now.
- assert request.args is not None
- if b"ts" in request.args and requester.app_service:
- event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
+ if requester.app_service:
+ origin_server_ts = parse_integer(request, "ts")
+ if origin_server_ts is not None:
+ event_dict["origin_server_ts"] = origin_server_ts
try:
(
@@ -510,7 +589,7 @@ class RoomMemberListRestServlet(RestServlet):
events = await handler.get_state_events(
room_id=room_id,
- user_id=requester.user.to_string(),
+ requester=requester,
at_token=at_token,
state_filter=StateFilter.from_types([(EventTypes.Member, None)]),
)
@@ -556,6 +635,7 @@ class RoomMessageListRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
super().__init__()
self._hs = hs
+ self.clock = hs.get_clock()
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
@@ -563,6 +643,18 @@ class RoomMessageListRestServlet(RestServlet):
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
+ processing_start_time = self.clock.time_msec()
+ # Fire off and hope that we get a result by the end.
+ #
+ # We're using the mypy type ignore comment because the `@cached`
+ # decorator on `get_number_joined_users_in_room` doesn't play well with
+ # the type system. Maybe in the future, it can use some ParamSpec
+ # wizardry to fix it up.
+ room_member_count_deferred = run_in_background( # type: ignore[call-arg]
+ self.store.get_number_joined_users_in_room,
+ room_id, # type: ignore[arg-type]
+ )
+
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request(
self.store, request, default_limit=10
@@ -593,6 +685,12 @@ class RoomMessageListRestServlet(RestServlet):
event_filter=event_filter,
)
+ processing_end_time = self.clock.time_msec()
+ room_member_count = await make_deferred_yieldable(room_member_count_deferred)
+ messsages_response_timer.labels(
+ room_size=_RoomSize.from_member_count(room_member_count)
+ ).observe((processing_end_time - processing_start_time) / 1000)
+
return 200, msgs
@@ -613,8 +711,7 @@ class RoomStateRestServlet(RestServlet):
# Get all the current state for this room
events = await self.message_handler.get_state_events(
room_id=room_id,
- user_id=requester.user.to_string(),
- is_guest=requester.is_guest,
+ requester=requester,
)
return 200, events
@@ -633,7 +730,9 @@ class RoomInitialSyncRestServlet(RestServlet):
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True)
- pagination_config = await PaginationConfig.from_request(self.store, request)
+ pagination_config = await PaginationConfig.from_request(
+ self.store, request, default_limit=10
+ )
content = await self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config
)
@@ -672,7 +771,7 @@ class RoomEventServlet(RestServlet):
== "true"
)
if include_unredacted_content and not await self.auth.is_server_admin(
- requester.user
+ requester
):
power_level_event = (
await self._storage_controllers.state.get_current_state_event(
@@ -860,7 +959,16 @@ class RoomMembershipRestServlet(TransactionRestServlet):
# cheekily send invalid bodies.
content = {}
- if membership_action == "invite" and self._has_3pid_invite_keys(content):
+ if membership_action == "invite" and all(
+ key in content for key in ("medium", "address")
+ ):
+ if not all(key in content for key in ("id_server", "id_access_token")):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "`id_server` and `id_access_token` are required when doing 3pid invite",
+ Codes.MISSING_PARAM,
+ )
+
try:
await self.room_member_handler.do_3pid_invite(
room_id,
@@ -870,7 +978,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
content["id_server"],
requester,
txn_id,
- content.get("id_access_token"),
+ content["id_access_token"],
)
except ShadowBanError:
# Pretend the request succeeded.
@@ -907,12 +1015,6 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return 200, return_value
- def _has_3pid_invite_keys(self, content: JsonDict) -> bool:
- for key in {"id_server", "medium", "address"}:
- if key not in content:
- return False
- return True
-
def on_PUT(
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
@@ -928,6 +1030,8 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.auth = hs.get_auth()
+ self._relation_handler = hs.get_relations_handler()
+ self._msc3912_enabled = hs.config.experimental.msc3912_enabled
def register(self, http_server: HttpServer) -> None:
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
@@ -944,20 +1048,46 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
content = parse_json_object_from_request(request)
try:
- (
- event,
- _,
- ) = await self.event_creation_handler.create_and_send_nonmember_event(
- requester,
- {
- "type": EventTypes.Redaction,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- "redacts": event_id,
- },
- txn_id=txn_id,
- )
+ with_relations = None
+ if self._msc3912_enabled and "org.matrix.msc3912.with_relations" in content:
+ with_relations = content["org.matrix.msc3912.with_relations"]
+ del content["org.matrix.msc3912.with_relations"]
+
+ # Check if there's an existing event for this transaction now (even though
+ # create_and_send_nonmember_event also does it) because, if there's one,
+ # then we want to skip the call to redact_events_related_to.
+ event = None
+ if txn_id:
+ event = await self.event_creation_handler.get_event_from_transaction(
+ requester, txn_id, room_id
+ )
+
+ if event is None:
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
+ requester,
+ {
+ "type": EventTypes.Redaction,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ "redacts": event_id,
+ },
+ txn_id=txn_id,
+ )
+
+ if with_relations:
+ run_as_background_process(
+ "redact_related_events",
+ self._relation_handler.redact_events_related_to,
+ requester=requester,
+ event_id=event_id,
+ initial_redaction_event=event,
+ relation_types=with_relations,
+ )
+
event_id = event.event_id
except ShadowBanError:
event_id = "$" + random_string(43)
@@ -1177,9 +1307,7 @@ class TimestampLookupRestServlet(RestServlet):
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request)
- await self._auth.check_user_in_room_or_world_readable(
- room_id, requester.user.to_string()
- )
+ await self._auth.check_user_in_room_or_world_readable(room_id, requester)
timestamp = parse_integer(request, "ts", required=True)
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])
diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py
index dd91dabedd..10be4a781b 100644
--- a/synapse/rest/client/room_batch.py
+++ b/synapse/rest/client/room_batch.py
@@ -108,6 +108,13 @@ class RoomBatchSendEventRestServlet(RestServlet):
errcode=Codes.MISSING_PARAM,
)
+ if await self.store.is_partial_state_room(room_id):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Cannot insert history batches until we have fully joined the room",
+ errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE,
+ )
+
# Verify the batch_id_from_query corresponds to an actual insertion event
# and have the batch connected.
if batch_id_from_query:
diff --git a/synapse/rest/client/sendtodevice.py b/synapse/rest/client/sendtodevice.py
index 1a8e9a96d4..46a8b03829 100644
--- a/synapse/rest/client/sendtodevice.py
+++ b/synapse/rest/client/sendtodevice.py
@@ -19,7 +19,7 @@ from synapse.http import servlet
from synapse.http.server import HttpServer
from synapse.http.servlet import assert_params_in_dict, parse_json_object_from_request
from synapse.http.site import SynapseRequest
-from synapse.logging.opentracing import set_tag, trace_with_opname
+from synapse.logging.opentracing import set_tag
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import JsonDict
@@ -43,7 +43,6 @@ class SendToDeviceRestServlet(servlet.RestServlet):
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
- @trace_with_opname("sendToDevice")
def on_PUT(
self, request: SynapseRequest, message_type: str, txn_id: str
) -> Awaitable[Tuple[int, JsonDict]]:
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index c2989765ce..f2013faeb2 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -100,6 +100,7 @@ class SyncRestServlet(RestServlet):
self._server_notices_sender = hs.get_server_notices_sender()
self._event_serializer = hs.get_event_client_serializer()
self._msc2654_enabled = hs.config.experimental.msc2654_enabled
+ self._msc3773_enabled = hs.config.experimental.msc3773_enabled
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
# This will always be set by the time Twisted calls us.
@@ -145,12 +146,12 @@ class SyncRestServlet(RestServlet):
elif filter_id.startswith("{"):
try:
filter_object = json_decoder.decode(filter_id)
- set_timeline_upper_limit(
- filter_object, self.hs.config.server.filter_timeline_limit
- )
except Exception:
- raise SynapseError(400, "Invalid filter JSON")
+ raise SynapseError(400, "Invalid filter JSON", errcode=Codes.NOT_JSON)
self.filtering.check_valid_filter(filter_object)
+ set_timeline_upper_limit(
+ filter_object, self.hs.config.server.filter_timeline_limit
+ )
filter_collection = FilterCollection(self.hs, filter_object)
else:
try:
@@ -509,6 +510,12 @@ class SyncRestServlet(RestServlet):
ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
+ if room.unread_thread_notifications:
+ result["unread_thread_notifications"] = room.unread_thread_notifications
+ if self._msc3773_enabled:
+ result[
+ "org.matrix.msc3773.unread_thread_notifications"
+ ] = room.unread_thread_notifications
result["summary"] = room.summary
if self._msc2654_enabled:
result["org.matrix.msc2654.unread_count"] = room.unread_count
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 0366986755..180a11ef88 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -75,6 +75,8 @@ class VersionsRestServlet(RestServlet):
"r0.6.1",
"v1.1",
"v1.2",
+ "v1.3",
+ "v1.4",
],
# as per MSC1497:
"unstable_features": {
@@ -94,7 +96,7 @@ class VersionsRestServlet(RestServlet):
# Supports the busy presence state described in MSC3026.
"org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
# Supports receiving private read receipts as per MSC2285
- "org.matrix.msc2285": self.config.experimental.msc2285_enabled,
+ "org.matrix.msc2285.stable": True, # TODO: Remove when MSC2285 becomes a part of the spec
# Supports filtering of /publicRooms by room type as per MSC3827
"org.matrix.msc3827.stable": True,
# Adds support for importing historical messages as per MSC2716
@@ -103,8 +105,22 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc3030": self.config.experimental.msc3030_enabled,
# Adds support for thread relations, per MSC3440.
"org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above
+ # Support for thread read receipts & notification counts.
+ "org.matrix.msc3771": True,
+ "org.matrix.msc3773": self.config.experimental.msc3773_enabled,
# Allows moderators to fetch redacted event content as described in MSC2815
"fi.mau.msc2815": self.config.experimental.msc2815_enabled,
+ # Adds support for login token requests as per MSC3882
+ "org.matrix.msc3882": self.config.experimental.msc3882_enabled,
+ # Adds support for remotely enabling/disabling pushers, as per MSC3881
+ "org.matrix.msc3881": self.config.experimental.msc3881_enabled,
+ # Adds support for filtering /messages by event relation.
+ "org.matrix.msc3874": self.config.experimental.msc3874_enabled,
+ # Adds support for simple HTTP rendezvous as per MSC3886
+ "org.matrix.msc3886": self.config.experimental.msc3886_endpoint
+ is not None,
+ # Adds support for relation-based redactions as per MSC3912.
+ "org.matrix.msc3912": self.config.experimental.msc3912_enabled,
},
},
)
diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py
index 7f8c1de1ff..26403facb8 100644
--- a/synapse/rest/key/v2/__init__.py
+++ b/synapse/rest/key/v2/__init__.py
@@ -14,17 +14,20 @@
from typing import TYPE_CHECKING
-from twisted.web.resource import Resource
-
-from .local_key_resource import LocalKey
-from .remote_key_resource import RemoteKey
+from synapse.http.server import HttpServer, JsonResource
+from synapse.rest.key.v2.local_key_resource import LocalKey
+from synapse.rest.key.v2.remote_key_resource import RemoteKey
if TYPE_CHECKING:
from synapse.server import HomeServer
-class KeyApiV2Resource(Resource):
+class KeyResource(JsonResource):
def __init__(self, hs: "HomeServer"):
- Resource.__init__(self)
- self.putChild(b"server", LocalKey(hs))
- self.putChild(b"query", RemoteKey(hs))
+ super().__init__(hs, canonical_json=True)
+ self.register_servlets(self, hs)
+
+ @staticmethod
+ def register_servlets(http_server: HttpServer, hs: "HomeServer") -> None:
+ LocalKey(hs).register(http_server)
+ RemoteKey(hs).register(http_server)
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 0c9f042c84..d03e728d42 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -13,16 +13,15 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Optional
+import re
+from typing import TYPE_CHECKING, Optional, Tuple
-from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json
from unpaddedbase64 import encode_base64
-from twisted.web.resource import Resource
from twisted.web.server import Request
-from synapse.http.server import respond_with_json_bytes
+from synapse.http.servlet import RestServlet
from synapse.types import JsonDict
if TYPE_CHECKING:
@@ -31,7 +30,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class LocalKey(Resource):
+class LocalKey(RestServlet):
"""HTTP resource containing encoding the TLS X.509 certificate and NACL
signature verification keys for this server::
@@ -61,18 +60,17 @@ class LocalKey(Resource):
}
"""
- isLeaf = True
+ PATTERNS = (re.compile("^/_matrix/key/v2/server(/(?P<key_id>[^/]*))?$"),)
def __init__(self, hs: "HomeServer"):
self.config = hs.config
self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
- Resource.__init__(self)
def update_response_body(self, time_now_msec: int) -> None:
refresh_interval = self.config.key.key_refresh_interval
self.valid_until_ts = int(time_now_msec + refresh_interval)
- self.response_body = encode_canonical_json(self.response_json_object())
+ self.response_body = self.response_json_object()
def response_json_object(self) -> JsonDict:
verify_keys = {}
@@ -99,9 +97,11 @@ class LocalKey(Resource):
json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object
- def render_GET(self, request: Request) -> Optional[int]:
+ def on_GET(
+ self, request: Request, key_id: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains.
if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:
self.update_response_body(time_now)
- return respond_with_json_bytes(request, 200, self.response_body)
+ return 200, self.response_body
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f597157581..19820886f5 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,15 +13,20 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Set
+import re
+from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
from signedjson.sign import sign_json
-from synapse.api.errors import Codes, SynapseError
+from twisted.web.server import Request
+
from synapse.crypto.keyring import ServerKeyFetcher
-from synapse.http.server import DirectServeJsonResource, respond_with_json
-from synapse.http.servlet import parse_integer, parse_json_object_from_request
-from synapse.http.site import SynapseRequest
+from synapse.http.server import HttpServer
+from synapse.http.servlet import (
+ RestServlet,
+ parse_integer,
+ parse_json_object_from_request,
+)
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import yieldable_gather_results
@@ -32,7 +37,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class RemoteKey(DirectServeJsonResource):
+class RemoteKey(RestServlet):
"""HTTP resource for retrieving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@@ -88,11 +93,7 @@ class RemoteKey(DirectServeJsonResource):
}
"""
- isLeaf = True
-
def __init__(self, hs: "HomeServer"):
- super().__init__()
-
self.fetcher = ServerKeyFetcher(hs)
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@@ -101,47 +102,52 @@ class RemoteKey(DirectServeJsonResource):
)
self.config = hs.config
- async def _async_render_GET(self, request: SynapseRequest) -> None:
- assert request.postpath is not None
- if len(request.postpath) == 1:
- (server,) = request.postpath
- query: dict = {server.decode("ascii"): {}}
- elif len(request.postpath) == 2:
- server, key_id = request.postpath
+ def register(self, http_server: HttpServer) -> None:
+ http_server.register_paths(
+ "GET",
+ (
+ re.compile(
+ "^/_matrix/key/v2/query/(?P<server>[^/]*)(/(?P<key_id>[^/]*))?$"
+ ),
+ ),
+ self.on_GET,
+ self.__class__.__name__,
+ )
+ http_server.register_paths(
+ "POST",
+ (re.compile("^/_matrix/key/v2/query$"),),
+ self.on_POST,
+ self.__class__.__name__,
+ )
+
+ async def on_GET(
+ self, request: Request, server: str, key_id: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
+ if server and key_id:
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
arguments = {}
if minimum_valid_until_ts is not None:
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
- query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}}
+ query = {server: {key_id: arguments}}
else:
- raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND)
+ query = {server: {}}
- await self.query_keys(request, query, query_remote_on_cache_miss=True)
+ return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
- async def _async_render_POST(self, request: SynapseRequest) -> None:
+ async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
query = content["server_keys"]
- await self.query_keys(request, query, query_remote_on_cache_miss=True)
+ return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
async def query_keys(
- self,
- request: SynapseRequest,
- query: JsonDict,
- query_remote_on_cache_miss: bool = False,
- ) -> None:
+ self, query: JsonDict, query_remote_on_cache_miss: bool = False
+ ) -> JsonDict:
logger.info("Handling query for keys %r", query)
store_queries = []
for server_name, key_ids in query.items():
- if (
- self.federation_domain_whitelist is not None
- and server_name not in self.federation_domain_whitelist
- ):
- logger.debug("Federation denied with %s", server_name)
- continue
-
if not key_ids:
key_ids = (None,)
for key_id in key_ids:
@@ -153,21 +159,28 @@ class RemoteKey(DirectServeJsonResource):
time_now_ms = self.clock.time_msec()
- # Note that the value is unused.
+ # Map server_name->key_id->int. Note that the value of the init is unused.
+ # XXX: why don't we just use a set?
cache_misses: Dict[str, Dict[str, int]] = {}
for (server_name, key_id, _), key_results in cached.items():
results = [(result["ts_added_ms"], result) for result in key_results]
- if not results and key_id is not None:
- cache_misses.setdefault(server_name, {})[key_id] = 0
+ if key_id is None:
+ # all keys were requested. Just return what we have without worrying
+ # about validity
+ for _, result in results:
+ # Cast to bytes since postgresql returns a memoryview.
+ json_results.add(bytes(result["key_json"]))
continue
- if key_id is not None:
+ miss = False
+ if not results:
+ miss = True
+ else:
ts_added_ms, most_recent_result = max(results)
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
req_key = query.get(server_name, {}).get(key_id, {})
req_valid_until = req_key.get("minimum_valid_until_ts")
- miss = False
if req_valid_until is not None:
if ts_valid_until_ms < req_valid_until:
logger.debug(
@@ -211,19 +224,20 @@ class RemoteKey(DirectServeJsonResource):
ts_valid_until_ms,
time_now_ms,
)
-
- if miss:
- cache_misses.setdefault(server_name, {})[key_id] = 0
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(most_recent_result["key_json"]))
- else:
- for _, result in results:
- # Cast to bytes since postgresql returns a memoryview.
- json_results.add(bytes(result["key_json"]))
+
+ if miss and query_remote_on_cache_miss:
+ # only bother attempting to fetch keys from servers on our whitelist
+ if (
+ self.federation_domain_whitelist is None
+ or server_name in self.federation_domain_whitelist
+ ):
+ cache_misses.setdefault(server_name, {})[key_id] = 0
# If there is a cache miss, request the missing keys, then recurse (and
# ensure the result is sent).
- if cache_misses and query_remote_on_cache_miss:
+ if cache_misses:
await yieldable_gather_results(
lambda t: self.fetcher.get_keys(*t),
(
@@ -231,7 +245,7 @@ class RemoteKey(DirectServeJsonResource):
for server_name, keys in cache_misses.items()
),
)
- await self.query_keys(request, query, query_remote_on_cache_miss=False)
+ return await self.query_keys(query, query_remote_on_cache_miss=False)
else:
signed_keys = []
for key_json_raw in json_results:
@@ -243,6 +257,4 @@ class RemoteKey(DirectServeJsonResource):
signed_keys.append(key_json)
- response = {"server_keys": signed_keys}
-
- respond_with_json(request, 200, response, canonical_json=True)
+ return {"server_keys": signed_keys}
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index c35d42fab8..d30878f704 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -254,30 +254,32 @@ async def respond_with_responder(
file_size: Size in bytes of the media. If not known it should be None
upload_name: The name of the requested file, if any.
"""
- if request._disconnected:
- logger.warning(
- "Not sending response to request %s, already disconnected.", request
- )
- return
-
if not responder:
respond_404(request)
return
- logger.debug("Responding to media request with responder %s", responder)
- add_file_headers(request, media_type, file_size, upload_name)
- try:
- with responder:
+ # If we have a responder we *must* use it as a context manager.
+ with responder:
+ if request._disconnected:
+ logger.warning(
+ "Not sending response to request %s, already disconnected.", request
+ )
+ return
+
+ logger.debug("Responding to media request with responder %s", responder)
+ add_file_headers(request, media_type, file_size, upload_name)
+ try:
+
await responder.write_to_consumer(request)
- except Exception as e:
- # The majority of the time this will be due to the client having gone
- # away. Unfortunately, Twisted simply throws a generic exception at us
- # in that case.
- logger.warning("Failed to write to consumer: %s %s", type(e), e)
-
- # Unregister the producer, if it has one, so Twisted doesn't complain
- if request.producer:
- request.unregisterProducer()
+ except Exception as e:
+ # The majority of the time this will be due to the client having gone
+ # away. Unfortunately, Twisted simply throws a generic exception at us
+ # in that case.
+ logger.warning("Failed to write to consumer: %s %s", type(e), e)
+
+ # Unregister the producer, if it has one, so Twisted doesn't complain
+ if request.producer:
+ request.unregisterProducer()
finish_request(request)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 7435fd9130..40b0d39eb2 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -19,6 +19,8 @@ import shutil
from io import BytesIO
from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from matrix_common.types.mxc_uri import MXCUri
+
import twisted.internet.error
import twisted.web.http
from twisted.internet.defer import Deferred
@@ -64,7 +66,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-
# How often to run the background job to update the "recently accessed"
# attribute of local and remote media.
UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 # 1 minute
@@ -187,7 +188,7 @@ class MediaRepository:
content: IO,
content_length: int,
auth_user: UserID,
- ) -> str:
+ ) -> MXCUri:
"""Store uploaded content for a local user and return the mxc URL
Args:
@@ -220,7 +221,7 @@ class MediaRepository:
await self._generate_thumbnails(None, media_id, media_id, media_type)
- return "mxc://%s/%s" % (self.server_name, media_id)
+ return MXCUri(self.server_name, media_id)
async def get_local_media(
self, request: SynapseRequest, media_id: str, name: Optional[str]
@@ -343,8 +344,8 @@ class MediaRepository:
download from remote server.
Args:
- server_name (str): Remote server_name where the media originated.
- media_id (str): The media ID of the content (as defined by the
+ server_name: Remote server_name where the media originated.
+ media_id: The media ID of the content (as defined by the
remote server).
Returns:
diff --git a/synapse/rest/media/v1/oembed.py b/synapse/rest/media/v1/oembed.py
index 2177b46c9e..827afd868d 100644
--- a/synapse/rest/media/v1/oembed.py
+++ b/synapse/rest/media/v1/oembed.py
@@ -139,65 +139,72 @@ class OEmbedProvider:
try:
# oEmbed responses *must* be UTF-8 according to the spec.
oembed = json_decoder.decode(raw_body.decode("utf-8"))
+ except ValueError:
+ return OEmbedResult({}, None, None)
- # The version is a required string field, but not always provided,
- # or sometimes provided as a float. Be lenient.
- oembed_version = oembed.get("version", "1.0")
- if oembed_version != "1.0" and oembed_version != 1:
- raise RuntimeError(f"Invalid oEmbed version: {oembed_version}")
+ # The version is a required string field, but not always provided,
+ # or sometimes provided as a float. Be lenient.
+ oembed_version = oembed.get("version", "1.0")
+ if oembed_version != "1.0" and oembed_version != 1:
+ return OEmbedResult({}, None, None)
- # Ensure the cache age is None or an int.
- cache_age = oembed.get("cache_age")
- if cache_age:
- cache_age = int(cache_age) * 1000
-
- # The results.
- open_graph_response = {
- "og:url": url,
- }
-
- title = oembed.get("title")
- if title:
- open_graph_response["og:title"] = title
-
- author_name = oembed.get("author_name")
+ # Attempt to parse the cache age, if possible.
+ try:
+ cache_age = int(oembed.get("cache_age")) * 1000
+ except (TypeError, ValueError):
+ # If the cache age cannot be parsed (e.g. wrong type or invalid
+ # string), ignore it.
+ cache_age = None
- # Use the provider name and as the site.
- provider_name = oembed.get("provider_name")
- if provider_name:
- open_graph_response["og:site_name"] = provider_name
+ # The oEmbed response converted to Open Graph.
+ open_graph_response: JsonDict = {"og:url": url}
- # If a thumbnail exists, use it. Note that dimensions will be calculated later.
- if "thumbnail_url" in oembed:
- open_graph_response["og:image"] = oembed["thumbnail_url"]
+ title = oembed.get("title")
+ if title and isinstance(title, str):
+ open_graph_response["og:title"] = title
- # Process each type separately.
- oembed_type = oembed["type"]
- if oembed_type == "rich":
- calc_description_and_urls(open_graph_response, oembed["html"])
-
- elif oembed_type == "photo":
- # If this is a photo, use the full image, not the thumbnail.
- open_graph_response["og:image"] = oembed["url"]
+ author_name = oembed.get("author_name")
+ if not isinstance(author_name, str):
+ author_name = None
- elif oembed_type == "video":
- open_graph_response["og:type"] = "video.other"
+ # Use the provider name and as the site.
+ provider_name = oembed.get("provider_name")
+ if provider_name and isinstance(provider_name, str):
+ open_graph_response["og:site_name"] = provider_name
+
+ # If a thumbnail exists, use it. Note that dimensions will be calculated later.
+ thumbnail_url = oembed.get("thumbnail_url")
+ if thumbnail_url and isinstance(thumbnail_url, str):
+ open_graph_response["og:image"] = thumbnail_url
+
+ # Process each type separately.
+ oembed_type = oembed.get("type")
+ if oembed_type == "rich":
+ html = oembed.get("html")
+ if isinstance(html, str):
+ calc_description_and_urls(open_graph_response, html)
+
+ elif oembed_type == "photo":
+ # If this is a photo, use the full image, not the thumbnail.
+ url = oembed.get("url")
+ if url and isinstance(url, str):
+ open_graph_response["og:image"] = url
+
+ elif oembed_type == "video":
+ open_graph_response["og:type"] = "video.other"
+ html = oembed.get("html")
+ if html and isinstance(html, str):
calc_description_and_urls(open_graph_response, oembed["html"])
- open_graph_response["og:video:width"] = oembed["width"]
- open_graph_response["og:video:height"] = oembed["height"]
-
- elif oembed_type == "link":
- open_graph_response["og:type"] = "website"
+ for size in ("width", "height"):
+ val = oembed.get(size)
+ if val is not None and isinstance(val, int):
+ open_graph_response[f"og:video:{size}"] = val
- else:
- raise RuntimeError(f"Unknown oEmbed type: {oembed_type}")
+ elif oembed_type == "link":
+ open_graph_response["og:type"] = "website"
- except Exception as e:
- # Trap any exception and let the code follow as usual.
- logger.warning("Error parsing oEmbed metadata from %s: %r", url, e)
- open_graph_response = {}
- author_name = None
- cache_age = None
+ else:
+ logger.warning("Unknown oEmbed type: %s", oembed_type)
return OEmbedResult(open_graph_response, author_name, cache_age)
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index b36c98a08e..a8f6fd6b35 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -732,10 +732,6 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("Running url preview cache expiry")
- if not (await self.store.db_pool.updates.has_completed_background_updates()):
- logger.debug("Still running DB updates; skipping url preview cache expiry")
- return
-
def try_remove_parent_dirs(dirs: Iterable[str]) -> None:
"""Attempt to remove the given chain of parent directories
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 9b93b9b4f6..a48a4de92a 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -138,7 +138,7 @@ class Thumbnailer:
"""Rescales the image to the given dimensions.
Returns:
- BytesIO: the bytes of the encoded image ready to be written to disk
+ The bytes of the encoded image ready to be written to disk
"""
with self._resize(width, height) as scaled:
return self._encode_image(scaled, output_type)
@@ -155,7 +155,7 @@ class Thumbnailer:
max_height: The largest possible height.
Returns:
- BytesIO: the bytes of the encoded image ready to be written to disk
+ The bytes of the encoded image ready to be written to disk
"""
if width * self.height > height * self.width:
scaled_width = width
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index e73e431dc9..97548b54e5 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -101,6 +101,8 @@ class UploadResource(DirectServeJsonResource):
# the default 404, as that would just be confusing.
raise SynapseError(400, "Bad content")
- logger.info("Uploaded content with URI %r", content_uri)
+ logger.info("Uploaded content with URI '%s'", content_uri)
- respond_with_json(request, 200, {"content_uri": content_uri}, send_cors=True)
+ respond_with_json(
+ request, 200, {"content_uri": str(content_uri)}, send_cors=True
+ )
diff --git a/synapse/rest/models.py b/synapse/rest/models.py
new file mode 100644
index 0000000000..ac39cda8e5
--- /dev/null
+++ b/synapse/rest/models.py
@@ -0,0 +1,23 @@
+from pydantic import BaseModel, Extra
+
+
+class RequestBodyModel(BaseModel):
+ """A custom version of Pydantic's BaseModel which
+
+ - ignores unknown fields and
+ - does not allow fields to be overwritten after construction,
+
+ but otherwise uses Pydantic's default behaviour.
+
+ Ignoring unknown fields is a useful default. It means that clients can provide
+ unstable field not known to the server without the request being refused outright.
+
+ Subclassing in this way is recommended by
+ https://pydantic-docs.helpmanual.io/usage/model_config/#change-behaviour-globally
+ """
+
+ class Config:
+ # By default, ignore fields that we don't recognise.
+ extra = Extra.ignore
+ # By default, don't allow fields to be reassigned after parsing.
+ allow_mutation = False
diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py
index 1c1c7b3613..22784157e6 100644
--- a/synapse/rest/synapse/client/new_user_consent.py
+++ b/synapse/rest/synapse/client/new_user_consent.py
@@ -20,6 +20,7 @@ from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request
from synapse.http.server import DirectServeHtmlResource, respond_with_html
from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
from synapse.types import UserID
from synapse.util.templates import build_jinja_env
@@ -88,7 +89,7 @@ class NewUserConsentResource(DirectServeHtmlResource):
html = template.render(template_params)
respond_with_html(request, 200, html)
- async def _async_render_POST(self, request: Request) -> None:
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
try:
session_id = get_username_mapping_session_cookie_from_request(request)
except SynapseError as e:
diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py
index 81fec39659..e4b28ce3df 100644
--- a/synapse/rest/synapse/client/oidc/__init__.py
+++ b/synapse/rest/synapse/client/oidc/__init__.py
@@ -17,6 +17,9 @@ from typing import TYPE_CHECKING
from twisted.web.resource import Resource
+from synapse.rest.synapse.client.oidc.backchannel_logout_resource import (
+ OIDCBackchannelLogoutResource,
+)
from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource
if TYPE_CHECKING:
@@ -29,6 +32,7 @@ class OIDCResource(Resource):
def __init__(self, hs: "HomeServer"):
Resource.__init__(self)
self.putChild(b"callback", OIDCCallbackResource(hs))
+ self.putChild(b"backchannel_logout", OIDCBackchannelLogoutResource(hs))
__all__ = ["OIDCResource"]
diff --git a/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py
new file mode 100644
index 0000000000..e07e76855a
--- /dev/null
+++ b/synapse/rest/synapse/client/oidc/backchannel_logout_resource.py
@@ -0,0 +1,35 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import TYPE_CHECKING
+
+from synapse.http.server import DirectServeJsonResource
+from synapse.http.site import SynapseRequest
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class OIDCBackchannelLogoutResource(DirectServeJsonResource):
+ isLeaf = 1
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__()
+ self._oidc_handler = hs.get_oidc_handler()
+
+ async def _async_render_POST(self, request: SynapseRequest) -> None:
+ await self._oidc_handler.handle_backchannel_logout(request)
diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py
index 6ac9dbc7c9..b9402cfb75 100644
--- a/synapse/rest/synapse/client/password_reset.py
+++ b/synapse/rest/synapse/client/password_reset.py
@@ -17,7 +17,6 @@ from typing import TYPE_CHECKING, Tuple
from twisted.web.server import Request
from synapse.api.errors import ThreepidValidationError
-from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.server import DirectServeHtmlResource
from synapse.http.servlet import parse_string
from synapse.util.stringutils import assert_valid_client_secret
@@ -46,9 +45,6 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
- self._local_threepid_handling_disabled_due_to_email_config = (
- hs.config.email.local_threepid_handling_disabled_due_to_email_config
- )
self._confirmation_email_template = (
hs.config.email.email_password_reset_template_confirmation_html
)
@@ -59,8 +55,8 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource):
hs.config.email.email_password_reset_template_failure_html
)
- # This resource should not be mounted if threepid behaviour is not LOCAL
- assert hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
+ # This resource should only be mounted if email validation is enabled
+ assert hs.config.email.can_verify_email
async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]:
sid = parse_string(request, "sid", required=True)
diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 6f7ac54c65..e2174fdfea 100644
--- a/synapse/rest/well_known.py
+++ b/synapse/rest/well_known.py
@@ -18,6 +18,7 @@ from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.http.server import set_cors_headers
+from synapse.http.site import SynapseRequest
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.stringutils import parse_server_name
@@ -63,7 +64,7 @@ class ClientWellKnownResource(Resource):
Resource.__init__(self)
self._well_known_builder = WellKnownBuilder(hs)
- def render_GET(self, request: Request) -> bytes:
+ def render_GET(self, request: SynapseRequest) -> bytes:
set_cors_headers(request)
r = self._well_known_builder.get_well_known()
if not r:
|