summary refs log tree commit diff
path: root/synapse/rest
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest')
-rw-r--r--synapse/rest/__init__.py10
-rw-r--r--synapse/rest/admin/__init__.py11
-rw-r--r--synapse/rest/admin/_base.py4
-rw-r--r--synapse/rest/admin/devices.py17
-rw-r--r--synapse/rest/admin/event_reports.py88
-rw-r--r--synapse/rest/admin/purge_room_servlet.py5
-rw-r--r--synapse/rest/admin/rooms.py16
-rw-r--r--synapse/rest/admin/server_notice_servlet.py9
-rw-r--r--synapse/rest/admin/users.py38
-rw-r--r--synapse/rest/client/transactions.py2
-rw-r--r--synapse/rest/client/v1/directory.py8
-rw-r--r--synapse/rest/client/v1/events.py7
-rw-r--r--synapse/rest/client/v1/initial_sync.py5
-rw-r--r--synapse/rest/client/v1/login.py133
-rw-r--r--synapse/rest/client/v1/logout.py4
-rw-r--r--synapse/rest/client/v1/presence.py2
-rw-r--r--synapse/rest/client/v1/profile.py6
-rw-r--r--synapse/rest/client/v1/push_rule.py37
-rw-r--r--synapse/rest/client/v1/pusher.py6
-rw-r--r--synapse/rest/client/v1/room.py226
-rw-r--r--synapse/rest/client/v1/voip.py2
-rw-r--r--synapse/rest/client/v2_alpha/account.py288
-rw-r--r--synapse/rest/client/v2_alpha/account_data.py4
-rw-r--r--synapse/rest/client/v2_alpha/account_validity.py4
-rw-r--r--synapse/rest/client/v2_alpha/auth.py138
-rw-r--r--synapse/rest/client/v2_alpha/capabilities.py2
-rw-r--r--synapse/rest/client/v2_alpha/devices.py6
-rw-r--r--synapse/rest/client/v2_alpha/filter.py4
-rw-r--r--synapse/rest/client/v2_alpha/groups.py52
-rw-r--r--synapse/rest/client/v2_alpha/keys.py15
-rw-r--r--synapse/rest/client/v2_alpha/notifications.py2
-rw-r--r--synapse/rest/client/v2_alpha/openid.py2
-rw-r--r--synapse/rest/client/v2_alpha/password_policy.py2
-rw-r--r--synapse/rest/client/v2_alpha/read_marker.py2
-rw-r--r--synapse/rest/client/v2_alpha/receipts.py2
-rw-r--r--synapse/rest/client/v2_alpha/register.py178
-rw-r--r--synapse/rest/client/v2_alpha/relations.py26
-rw-r--r--synapse/rest/client/v2_alpha/report_event.py2
-rw-r--r--synapse/rest/client/v2_alpha/room_keys.py6
-rw-r--r--synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py16
-rw-r--r--synapse/rest/client/v2_alpha/sendtodevice.py2
-rw-r--r--synapse/rest/client/v2_alpha/shared_rooms.py68
-rw-r--r--synapse/rest/client/v2_alpha/sync.py18
-rw-r--r--synapse/rest/client/v2_alpha/tags.py4
-rw-r--r--synapse/rest/client/v2_alpha/thirdparty.py8
-rw-r--r--synapse/rest/client/v2_alpha/tokenrefresh.py2
-rw-r--r--synapse/rest/client/v2_alpha/user_directory.py2
-rw-r--r--synapse/rest/client/versions.py23
-rw-r--r--synapse/rest/consent/consent_resource.py4
-rw-r--r--synapse/rest/health.py31
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py14
-rw-r--r--synapse/rest/media/v1/_base.py12
-rw-r--r--synapse/rest/media/v1/filepath.py21
-rw-r--r--synapse/rest/media/v1/media_repository.py177
-rw-r--r--synapse/rest/media/v1/media_storage.py106
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py35
-rw-r--r--synapse/rest/media/v1/storage_provider.py77
-rw-r--r--synapse/rest/media/v1/thumbnail_resource.py5
-rw-r--r--synapse/rest/media/v1/thumbnailer.py20
-rw-r--r--synapse/rest/media/v1/upload_resource.py4
-rw-r--r--synapse/rest/saml2/response_resource.py16
-rw-r--r--synapse/rest/synapse/__init__.py14
-rw-r--r--synapse/rest/synapse/client/__init__.py14
-rw-r--r--synapse/rest/synapse/client/password_reset.py127
-rw-r--r--synapse/rest/well_known.py6
65 files changed, 1339 insertions, 858 deletions
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py

index 2e81eeff65..5d9cdf4bde 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py
@@ -13,8 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import synapse.rest.admin from synapse.http.server import JsonResource +from synapse.rest import admin from synapse.rest.client import versions from synapse.rest.client.v1 import ( directory, @@ -50,6 +50,7 @@ from synapse.rest.client.v2_alpha import ( room_keys, room_upgrade_rest_servlet, sendtodevice, + shared_rooms, sync, tags, thirdparty, @@ -123,6 +124,7 @@ class ClientRestResource(JsonResource): password_policy.register_servlets(hs, client_resource) # moving to /_synapse/admin - synapse.rest.admin.register_servlets_for_client_rest_resource( - hs, client_resource - ) + admin.register_servlets_for_client_rest_resource(hs, client_resource) + + # unstable + shared_rooms.register_servlets(hs, client_resource) diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 1c88c93f38..57cac22252 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py
@@ -16,13 +16,13 @@ import logging import platform -import re import synapse from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.http.server import JsonResource from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.rest.admin._base import ( + admin_patterns, assert_requester_is_admin, historical_admin_path_patterns, ) @@ -31,6 +31,7 @@ from synapse.rest.admin.devices import ( DeviceRestServlet, DevicesRestServlet, ) +from synapse.rest.admin.event_reports import EventReportsRestServlet from synapse.rest.admin.groups import DeleteGroupAdminRestServlet from synapse.rest.admin.media import ListMediaInRoom, register_servlets_for_media_repo from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet @@ -49,6 +50,7 @@ from synapse.rest.admin.users import ( ResetPasswordRestServlet, SearchUsersRestServlet, UserAdminServlet, + UserMembershipRestServlet, UserRegisterServlet, UserRestServletV2, UsersRestServlet, @@ -61,7 +63,7 @@ logger = logging.getLogger(__name__) class VersionServlet(RestServlet): - PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),) + PATTERNS = admin_patterns("/server_version$") def __init__(self, hs): self.res = { @@ -107,7 +109,8 @@ class PurgeHistoryRestServlet(RestServlet): if event.room_id != room_id: raise SynapseError(400, "Event is for wrong room.") - token = await self.store.get_topological_token_for_event(event_id) + room_token = await self.store.get_topological_token_for_event(event_id) + token = await room_token.to_string(self.store) logger.info("[purge] purging up to token %s (event_id %s)", token, event_id) elif "purge_up_to_ts" in body: @@ -209,11 +212,13 @@ def register_servlets(hs, http_server): SendServerNoticeServlet(hs).register(http_server) VersionServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server) + UserMembershipRestServlet(hs).register(http_server) UserRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server) DeviceRestServlet(hs).register(http_server) DevicesRestServlet(hs).register(http_server) DeleteDevicesRestServlet(hs).register(http_server) + EventReportsRestServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index d82eaf5e38..db9fea263a 100644 --- a/synapse/rest/admin/_base.py +++ b/synapse/rest/admin/_base.py
@@ -44,7 +44,7 @@ def historical_admin_path_patterns(path_regex): ] -def admin_patterns(path_regex: str): +def admin_patterns(path_regex: str, version: str = "v1"): """Returns the list of patterns for an admin endpoint Args: @@ -54,7 +54,7 @@ def admin_patterns(path_regex: str): Returns: A list of regex patterns. """ - admin_prefix = "^/_synapse/admin/v1" + admin_prefix = "^/_synapse/admin/" + version patterns = [re.compile(admin_prefix + path_regex)] return patterns diff --git a/synapse/rest/admin/devices.py b/synapse/rest/admin/devices.py
index 8d32677339..a163863322 100644 --- a/synapse/rest/admin/devices.py +++ b/synapse/rest/admin/devices.py
@@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import re from synapse.api.errors import NotFoundError, SynapseError from synapse.http.servlet import ( @@ -21,7 +20,7 @@ from synapse.http.servlet import ( assert_params_in_dict, parse_json_object_from_request, ) -from synapse.rest.admin._base import assert_requester_is_admin +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.types import UserID logger = logging.getLogger(__name__) @@ -32,14 +31,12 @@ class DeviceRestServlet(RestServlet): Get, update or delete the given user's device """ - PATTERNS = ( - re.compile( - "^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$" - ), + PATTERNS = admin_patterns( + "/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2" ) def __init__(self, hs): - super(DeviceRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() @@ -98,7 +95,7 @@ class DevicesRestServlet(RestServlet): Retrieve the given user's devices """ - PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices$"),) + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2") def __init__(self, hs): """ @@ -131,9 +128,7 @@ class DeleteDevicesRestServlet(RestServlet): key which lists the device_ids to delete. """ - PATTERNS = ( - re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/delete_devices$"), - ) + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2") def __init__(self, hs): self.hs = hs diff --git a/synapse/rest/admin/event_reports.py b/synapse/rest/admin/event_reports.py new file mode 100644
index 0000000000..5b8d0594cd --- /dev/null +++ b/synapse/rest/admin/event_reports.py
@@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Dirk Klimpel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet, parse_integer, parse_string +from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin + +logger = logging.getLogger(__name__) + + +class EventReportsRestServlet(RestServlet): + """ + List all reported events that are known to the homeserver. Results are returned + in a dictionary containing report information. Supports pagination. + The requester must have administrator access in Synapse. + + GET /_synapse/admin/v1/event_reports + returns: + 200 OK with list of reports if success otherwise an error. + + Args: + The parameters `from` and `limit` are required only for pagination. + By default, a `limit` of 100 is used. + The parameter `dir` can be used to define the order of results. + The parameter `user_id` can be used to filter by user id. + The parameter `room_id` can be used to filter by room id. + Returns: + A list of reported events and an integer representing the total number of + reported events that exist given this query + """ + + PATTERNS = admin_patterns("/event_reports$") + + def __init__(self, hs): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request): + await assert_requester_is_admin(self.auth, request) + + start = parse_integer(request, "from", default=0) + limit = parse_integer(request, "limit", default=100) + direction = parse_string(request, "dir", default="b") + user_id = parse_string(request, "user_id") + room_id = parse_string(request, "room_id") + + if start < 0: + raise SynapseError( + 400, + "The start parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if limit < 0: + raise SynapseError( + 400, + "The limit parameter must be a positive integer.", + errcode=Codes.INVALID_PARAM, + ) + + if direction not in ("f", "b"): + raise SynapseError( + 400, "Unknown direction: %s" % (direction,), errcode=Codes.INVALID_PARAM + ) + + event_reports, total = await self.store.get_event_reports_paginate( + start, limit, direction, user_id, room_id + ) + ret = {"event_reports": event_reports, "total": total} + if (start + limit) < total: + ret["next_token"] = start + len(event_reports) + + return 200, ret diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
index f474066542..8b7bb6d44e 100644 --- a/synapse/rest/admin/purge_room_servlet.py +++ b/synapse/rest/admin/purge_room_servlet.py
@@ -12,14 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import re - from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) from synapse.rest.admin import assert_requester_is_admin +from synapse.rest.admin._base import admin_patterns class PurgeRoomServlet(RestServlet): @@ -35,7 +34,7 @@ class PurgeRoomServlet(RestServlet): {} """ - PATTERNS = (re.compile("^/_synapse/admin/v1/purge_room$"),) + PATTERNS = admin_patterns("/purge_room$") def __init__(self, hs): """ diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index b8c95d045a..09726d52d6 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py
@@ -31,7 +31,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, historical_admin_path_patterns, ) -from synapse.storage.data_stores.main.room import RoomSortOrder +from synapse.storage.databases.main.room import RoomSortOrder from synapse.types import RoomAlias, RoomID, UserID, create_requester logger = logging.getLogger(__name__) @@ -103,6 +103,14 @@ class DeleteRoomRestServlet(RestServlet): Codes.BAD_JSON, ) + purge = content.get("purge", True) + if not isinstance(purge, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "Param 'purge' must be a boolean, if given", + Codes.BAD_JSON, + ) + ret = await self.room_shutdown_handler.shutdown_room( room_id=room_id, new_room_user_id=content.get("new_room_user_id"), @@ -113,7 +121,8 @@ class DeleteRoomRestServlet(RestServlet): ) # Purge room - await self.pagination_handler.purge_room(room_id) + if purge: + await self.pagination_handler.purge_room(room_id) return (200, ret) @@ -307,6 +316,9 @@ class JoinRoomAliasServlet(RestServlet): join_rules_event = room_state.get((EventTypes.JoinRules, "")) if join_rules_event: if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC): + # update_membership with an action of "invite" can raise a + # ShadowBanError. This is not handled since it is assumed that + # an admin isn't going to call this API with a shadow-banned user. await self.room_member_handler.update_membership( requester=requester, target=fake_requester.user, diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 6e9a874121..375d055445 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,8 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import re - from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError from synapse.http.servlet import ( @@ -22,6 +20,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) from synapse.rest.admin import assert_requester_is_admin +from synapse.rest.admin._base import admin_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.types import UserID @@ -56,13 +55,13 @@ class SendServerNoticeServlet(RestServlet): self.snm = hs.get_server_notices_manager() def register(self, json_resource): - PATTERN = "^/_synapse/admin/v1/send_server_notice" + PATTERN = "/send_server_notice" json_resource.register_paths( - "POST", (re.compile(PATTERN + "$"),), self.on_POST, self.__class__.__name__ + "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__ ) json_resource.register_paths( "PUT", - (re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),), + admin_patterns(PATTERN + "/(?P<txn_id>[^/]*)$"), self.on_PUT, self.__class__.__name__, ) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index cc0bdfa5c9..20dc1d0e05 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py
@@ -15,7 +15,6 @@ import hashlib import hmac import logging -import re from http import HTTPStatus from synapse.api.constants import UserTypes @@ -29,6 +28,7 @@ from synapse.http.servlet import ( parse_string, ) from synapse.rest.admin._base import ( + admin_patterns, assert_requester_is_admin, assert_user_is_admin, historical_admin_path_patterns, @@ -60,7 +60,7 @@ class UsersRestServlet(RestServlet): class UsersRestServletV2(RestServlet): - PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),) + PATTERNS = admin_patterns("/users$", "v2") """Get request to list all local users. This needs user to have administrator access in Synapse. @@ -73,6 +73,7 @@ class UsersRestServletV2(RestServlet): The parameters `from` and `limit` are required only for pagination. By default, a `limit` of 100 is used. The parameter `user_id` can be used to filter by user id. + The parameter `name` can be used to filter by user id or display name. The parameter `guests` can be used to exclude guest users. The parameter `deactivated` can be used to include deactivated users. """ @@ -89,11 +90,12 @@ class UsersRestServletV2(RestServlet): start = parse_integer(request, "from", default=0) limit = parse_integer(request, "limit", default=100) user_id = parse_string(request, "user_id", default=None) + name = parse_string(request, "name", default=None) guests = parse_boolean(request, "guests", default=True) deactivated = parse_boolean(request, "deactivated", default=False) users, total = await self.store.get_users_paginate( - start, limit, user_id, guests, deactivated + start, limit, user_id, name, guests, deactivated ) ret = {"users": users, "total": total} if len(users) >= limit: @@ -103,7 +105,7 @@ class UsersRestServletV2(RestServlet): class UserRestServletV2(RestServlet): - PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]+)$"),) + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2") """Get request to list user details. This needs user to have administrator access in Synapse. @@ -640,7 +642,7 @@ class UserAdminServlet(RestServlet): {} """ - PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>[^/]*)/admin$"),) + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$") def __init__(self, hs): self.hs = hs @@ -681,3 +683,29 @@ class UserAdminServlet(RestServlet): await self.store.set_server_admin(target_user, set_admin_to) return 200, {} + + +class UserMembershipRestServlet(RestServlet): + """ + Get room list of an user. + """ + + PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$") + + def __init__(self, hs): + self.is_mine = hs.is_mine + self.auth = hs.get_auth() + self.store = hs.get_datastore() + + async def on_GET(self, request, user_id): + await assert_requester_is_admin(self.auth, request) + + if not self.is_mine(UserID.from_string(user_id)): + raise SynapseError(400, "Can only lookup local users") + + room_ids = await self.store.get_rooms_for_user(user_id) + if not room_ids: + raise NotFoundError("User not found") + + ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} + return 200, ret diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py
index 6da71dc46f..7be5c0fb88 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__) CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins -class HttpTransactionCache(object): +class HttpTransactionCache: def __init__(self, hs): self.hs = hs self.auth = self.hs.get_auth() diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py
index 5934b1fe8b..faabeeb91c 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py
@@ -40,7 +40,7 @@ class ClientDirectoryServer(RestServlet): PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True) def __init__(self, hs): - super(ClientDirectoryServer, self).__init__() + super().__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() self.auth = hs.get_auth() @@ -89,7 +89,7 @@ class ClientDirectoryServer(RestServlet): dir_handler = self.handlers.directory_handler try: - service = await self.auth.get_appservice_by_req(request) + service = self.auth.get_appservice_by_req(request) room_alias = RoomAlias.from_string(room_alias) await dir_handler.delete_appservice_association(service, room_alias) logger.info( @@ -120,7 +120,7 @@ class ClientDirectoryListServer(RestServlet): PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True) def __init__(self, hs): - super(ClientDirectoryListServer, self).__init__() + super().__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() self.auth = hs.get_auth() @@ -160,7 +160,7 @@ class ClientAppserviceDirectoryListServer(RestServlet): ) def __init__(self, hs): - super(ClientAppserviceDirectoryListServer, self).__init__() + super().__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py
index 25effd0261..1ecb77aa26 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py
@@ -30,9 +30,10 @@ class EventStreamRestServlet(RestServlet): DEFAULT_LONGPOLL_TIME_MS = 30000 def __init__(self, hs): - super(EventStreamRestServlet, self).__init__() + super().__init__() self.event_stream_handler = hs.get_event_stream_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -44,7 +45,7 @@ class EventStreamRestServlet(RestServlet): if b"room_id" in request.args: room_id = request.args[b"room_id"][0].decode("ascii") - pagin_config = PaginationConfig.from_request(request) + pagin_config = await PaginationConfig.from_request(self.store, request) timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS if b"timeout" in request.args: try: @@ -74,7 +75,7 @@ class EventRestServlet(RestServlet): PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True) def __init__(self, hs): - super(EventRestServlet, self).__init__() + super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 910b3b4eeb..91da0ee573 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py
@@ -24,14 +24,15 @@ class InitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/initialSync$", v1=True) def __init__(self, hs): - super(InitialSyncRestServlet, self).__init__() + super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request) as_client_event = b"raw" not in request.args - pagination_config = PaginationConfig.from_request(request) + pagination_config = await PaginationConfig.from_request(self.store, request) include_archived = parse_boolean(request, "archived", default=False) content = await self.initial_sync_handler.snapshot_all_rooms( user_id=requester.user.to_string(), diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 379f668d6f..3d1693d7ac 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py
@@ -18,6 +18,11 @@ from typing import Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter +from synapse.appservice import ApplicationService +from synapse.handlers.auth import ( + convert_client_dict_legacy_fields_to_identifier, + login_id_phone_to_thirdparty, +) from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, @@ -28,56 +33,11 @@ from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import JsonDict, UserID -from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email logger = logging.getLogger(__name__) -def login_submission_legacy_convert(submission): - """ - If the input login submission is an old style object - (ie. with top-level user / medium / address) convert it - to a typed object. - """ - if "user" in submission: - submission["identifier"] = {"type": "m.id.user", "user": submission["user"]} - del submission["user"] - - if "medium" in submission and "address" in submission: - submission["identifier"] = { - "type": "m.id.thirdparty", - "medium": submission["medium"], - "address": submission["address"], - } - del submission["medium"] - del submission["address"] - - -def login_id_thirdparty_from_phone(identifier): - """ - Convert a phone login identifier type to a generic threepid identifier - Args: - identifier(dict): Login identifier dict of type 'm.id.phone' - - Returns: Login identifier dict of type 'm.id.threepid' - """ - if "country" not in identifier or ( - # The specification requires a "phone" field, while Synapse used to require a "number" - # field. Accept both for backwards compatibility. - "phone" not in identifier - and "number" not in identifier - ): - raise SynapseError(400, "Invalid phone-type identifier") - - # Accept both "phone" and "number" as valid keys in m.id.phone - phone_number = identifier.get("phone", identifier["number"]) - - msisdn = phone_number_to_msisdn(identifier["country"], phone_number) - - return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn} - - class LoginRestServlet(RestServlet): PATTERNS = client_patterns("/login$", v1=True) CAS_TYPE = "m.login.cas" @@ -85,9 +45,10 @@ class LoginRestServlet(RestServlet): TOKEN_TYPE = "m.login.token" JWT_TYPE = "org.matrix.login.jwt" JWT_TYPE_DEPRECATED = "m.login.jwt" + APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service" def __init__(self, hs): - super(LoginRestServlet, self).__init__() + super().__init__() self.hs = hs # JWT configuration variables. @@ -102,6 +63,8 @@ class LoginRestServlet(RestServlet): self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self.auth = hs.get_auth() + self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.handlers = hs.get_handlers() @@ -157,8 +120,12 @@ class LoginRestServlet(RestServlet): self._address_ratelimiter.ratelimit(request.getClientIP()) login_submission = parse_json_object_from_request(request) + try: - if self.jwt_enabled and ( + if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE: + appservice = self.auth.get_appservice_by_req(request) + result = await self._do_appservice_login(login_submission, appservice) + elif self.jwt_enabled and ( login_submission["type"] == LoginRestServlet.JWT_TYPE or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED ): @@ -175,6 +142,33 @@ class LoginRestServlet(RestServlet): result["well_known"] = well_known_data return 200, result + def _get_qualified_user_id(self, identifier): + if identifier["type"] != "m.id.user": + raise SynapseError(400, "Unknown login identifier type") + if "user" not in identifier: + raise SynapseError(400, "User identifier is missing 'user' key") + + if identifier["user"].startswith("@"): + return identifier["user"] + else: + return UserID(identifier["user"], self.hs.hostname).to_string() + + async def _do_appservice_login( + self, login_submission: JsonDict, appservice: ApplicationService + ): + logger.info( + "Got appservice login request with identifier: %r", + login_submission.get("identifier"), + ) + + identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) + qualified_user_id = self._get_qualified_user_id(identifier) + + if not appservice.is_interested_in_user(qualified_user_id): + raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN) + + return await self._complete_login(qualified_user_id, login_submission) + async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]: """Handle non-token/saml/jwt logins @@ -194,18 +188,11 @@ class LoginRestServlet(RestServlet): login_submission.get("address"), login_submission.get("user"), ) - login_submission_legacy_convert(login_submission) - - if "identifier" not in login_submission: - raise SynapseError(400, "Missing param: identifier") - - identifier = login_submission["identifier"] - if "type" not in identifier: - raise SynapseError(400, "Login identifier has no type") + identifier = convert_client_dict_legacy_fields_to_identifier(login_submission) # convert phone type identifiers to generic threepids if identifier["type"] == "m.id.phone": - identifier = login_id_thirdparty_from_phone(identifier) + identifier = login_id_phone_to_thirdparty(identifier) # convert threepid identifiers to user IDs if identifier["type"] == "m.id.thirdparty": @@ -267,15 +254,7 @@ class LoginRestServlet(RestServlet): # by this point, the identifier should be an m.id.user: if it's anything # else, we haven't understood it. - if identifier["type"] != "m.id.user": - raise SynapseError(400, "Unknown login identifier type") - if "user" not in identifier: - raise SynapseError(400, "User identifier is missing 'user' key") - - if identifier["user"].startswith("@"): - qualified_user_id = identifier["user"] - else: - qualified_user_id = UserID(identifier["user"], self.hs.hostname).to_string() + qualified_user_id = self._get_qualified_user_id(identifier) # Check if we've hit the failed ratelimit (but don't update it) self._failed_attempts_ratelimiter.ratelimit( @@ -303,9 +282,7 @@ class LoginRestServlet(RestServlet): self, user_id: str, login_submission: JsonDict, - callback: Optional[ - Callable[[Dict[str, str]], Awaitable[Dict[str, str]]] - ] = None, + callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, create_non_existent_users: bool = False, ) -> Dict[str, str]: """Called when we've successfully authed the user and now need to @@ -318,12 +295,12 @@ class LoginRestServlet(RestServlet): Args: user_id: ID of the user to register. login_submission: Dictionary of login information. - callback: Callback function to run after registration. + callback: Callback function to run after login. create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. Returns: - result: Dictionary of account information after successful registration. + result: Dictionary of account information after successful login. """ # Before we actually log them in we check if they've already logged in @@ -358,14 +335,24 @@ class LoginRestServlet(RestServlet): return result async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]: + """ + Handle the final stage of SSO login. + + Args: + login_submission: The JSON request body. + + Returns: + The body of the JSON response. + """ token = login_submission["token"] auth_handler = self.auth_handler user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( token ) - result = await self._complete_login(user_id, login_submission) - return result + return await self._complete_login( + user_id, login_submission, self.auth_handler._sso_login_callback + ) async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: token = login_submission.get("token", None) @@ -448,7 +435,7 @@ class CasTicketServlet(RestServlet): PATTERNS = client_patterns("/login/cas/ticket", v1=True) def __init__(self, hs): - super(CasTicketServlet, self).__init__() + super().__init__() self._cas_handler = hs.get_cas_handler() async def on_GET(self, request: SynapseRequest) -> None: diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index b0c30b65be..f792b50cdc 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py
@@ -25,7 +25,7 @@ class LogoutRestServlet(RestServlet): PATTERNS = client_patterns("/logout$", v1=True) def __init__(self, hs): - super(LogoutRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() @@ -53,7 +53,7 @@ class LogoutAllRestServlet(RestServlet): PATTERNS = client_patterns("/logout/all$", v1=True) def __init__(self, hs): - super(LogoutAllRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py
index ceaa28c212..4796cdac05 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py
@@ -30,7 +30,7 @@ class PresenceStatusRestServlet(RestServlet): PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True) def __init__(self, hs): - super(PresenceStatusRestServlet, self).__init__() + super().__init__() self.hs = hs self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py
index 165313b572..204b2ec9e5 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py
@@ -26,7 +26,7 @@ class ProfileDisplaynameRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True) def __init__(self, hs): - super(ProfileDisplaynameRestServlet, self).__init__() + super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.http_client = hs.get_simple_http_client() @@ -91,7 +91,7 @@ class ProfileAvatarURLRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True) def __init__(self, hs): - super(ProfileAvatarURLRestServlet, self).__init__() + super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.http_client = hs.get_simple_http_client() @@ -159,7 +159,7 @@ class ProfileRestServlet(RestServlet): PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True) def __init__(self, hs): - super(ProfileRestServlet, self).__init__() + super().__init__() self.hs = hs self.profile_handler = hs.get_profile_handler() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 9fd4908136..f9eecb7cf5 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py
@@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - from synapse.api.errors import ( NotFoundError, StoreError, @@ -25,7 +24,7 @@ from synapse.http.servlet import ( parse_json_value_from_request, parse_string, ) -from synapse.push.baserules import BASE_RULE_IDS +from synapse.push.baserules import BASE_RULE_IDS, NEW_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP from synapse.rest.client.v2_alpha._base import client_patterns @@ -39,12 +38,14 @@ class PushRuleRestServlet(RestServlet): ) def __init__(self, hs): - super(PushRuleRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() self._is_worker = hs.config.worker_app is not None + self._users_new_default_push_rules = hs.config.users_new_default_push_rules + async def on_PUT(self, request, path): if self._is_worker: raise Exception("Cannot handle PUT /push_rules on worker") @@ -158,10 +159,22 @@ class PushRuleRestServlet(RestServlet): return 200, {} def notify_user(self, user_id): - stream_id, _ = self.store.get_push_rules_stream_token() + stream_id = self.store.get_max_push_rules_stream_id() self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id]) - def set_rule_attr(self, user_id, spec, val): + async def set_rule_attr(self, user_id, spec, val): + if spec["attr"] not in ("enabled", "actions"): + # for the sake of potential future expansion, shouldn't report + # 404 in the case of an unknown request so check it corresponds to + # a known attribute first. + raise UnrecognizedRequestError() + + namespaced_rule_id = _namespaced_rule_id_from_spec(spec) + rule_id = spec["rule_id"] + is_default_rule = rule_id.startswith(".") + if is_default_rule: + if namespaced_rule_id not in BASE_RULE_IDS: + raise NotFoundError("Unknown rule %s" % (namespaced_rule_id,)) if spec["attr"] == "enabled": if isinstance(val, dict) and "enabled" in val: val = val["enabled"] @@ -170,8 +183,9 @@ class PushRuleRestServlet(RestServlet): # This should *actually* take a dict, but many clients pass # bools directly, so let's not break them. raise SynapseError(400, "Value for 'enabled' must be boolean") - namespaced_rule_id = _namespaced_rule_id_from_spec(spec) - return self.store.set_push_rule_enabled(user_id, namespaced_rule_id, val) + return await self.store.set_push_rule_enabled( + user_id, namespaced_rule_id, val, is_default_rule + ) elif spec["attr"] == "actions": actions = val.get("actions") _check_actions(actions) @@ -179,9 +193,14 @@ class PushRuleRestServlet(RestServlet): rule_id = spec["rule_id"] is_default_rule = rule_id.startswith(".") if is_default_rule: - if namespaced_rule_id not in BASE_RULE_IDS: + if user_id in self._users_new_default_push_rules: + rule_ids = NEW_RULE_IDS + else: + rule_ids = BASE_RULE_IDS + + if namespaced_rule_id not in rule_ids: raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,)) - return self.store.set_push_rule_actions( + return await self.store.set_push_rule_actions( user_id, namespaced_rule_id, actions, is_default_rule ) else: diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index 5f65cb7d83..28dabf1c7a 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py
@@ -44,7 +44,7 @@ class PushersRestServlet(RestServlet): PATTERNS = client_patterns("/pushers$", v1=True) def __init__(self, hs): - super(PushersRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -68,7 +68,7 @@ class PushersSetRestServlet(RestServlet): PATTERNS = client_patterns("/pushers/set$", v1=True) def __init__(self, hs): - super(PushersSetRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.notifier = hs.get_notifier() @@ -153,7 +153,7 @@ class PushersRemoveRestServlet(RestServlet): SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>" def __init__(self, hs): - super(PushersRemoveRestServlet, self).__init__() + super().__init__() self.hs = hs self.notifier = hs.get_notifier() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 1a3398316d..b421fe855e 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py
@@ -21,14 +21,13 @@ import re from typing import List, Optional from urllib import parse as urlparse -from canonicaljson import json - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, HttpResponseException, InvalidClientCredentialsError, + ShadowBanError, SynapseError, ) from synapse.api.filtering import Filter @@ -46,6 +45,8 @@ from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID +from synapse.util import json_decoder +from synapse.util.stringutils import random_string MYPY = False if MYPY: @@ -56,7 +57,7 @@ logger = logging.getLogger(__name__) class TransactionRestServlet(RestServlet): def __init__(self, hs): - super(TransactionRestServlet, self).__init__() + super().__init__() self.txns = HttpTransactionCache(hs) @@ -64,7 +65,7 @@ class RoomCreateRestServlet(TransactionRestServlet): # No PATTERN; we have custom dispatch rules here def __init__(self, hs): - super(RoomCreateRestServlet, self).__init__(hs) + super().__init__(hs) self._room_creation_handler = hs.get_room_creation_handler() self.auth = hs.get_auth() @@ -110,7 +111,7 @@ class RoomCreateRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events class RoomStateEventRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomStateEventRestServlet, self).__init__(hs) + super().__init__(hs) self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() @@ -170,7 +171,6 @@ class RoomStateEventRestServlet(TransactionRestServlet): room_id=room_id, event_type=event_type, state_key=state_key, - is_guest=requester.is_guest, ) if not data: @@ -200,23 +200,26 @@ class RoomStateEventRestServlet(TransactionRestServlet): if state_key is not None: event_dict["state_key"] = state_key - if event_type == EventTypes.Member: - membership = content.get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - requester, - target=UserID.from_string(state_key), - room_id=room_id, - action=membership, - content=content, - ) - else: - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, txn_id=txn_id - ) - event_id = event.event_id + try: + if event_type == EventTypes.Member: + membership = content.get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + requester, + target=UserID.from_string(state_key), + room_id=room_id, + action=membership, + content=content, + ) + else: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) set_tag("event_id", event_id) ret = {"event_id": event_id} @@ -226,7 +229,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events + feedback class RoomSendEventRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomSendEventRestServlet, self).__init__(hs) + super().__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() @@ -249,12 +252,19 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event, _ = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict, txn_id=txn_id - ) + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) - set_tag("event_id", event.event_id) - return 200, {"event_id": event.event_id} + set_tag("event_id", event_id) + return 200, {"event_id": event_id} def on_GET(self, request, room_id, event_type, txn_id): return 200, "Not implemented" @@ -270,7 +280,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): # TODO: Needs unit testing for room ID + alias joins class JoinRoomAliasServlet(TransactionRestServlet): def __init__(self, hs): - super(JoinRoomAliasServlet, self).__init__(hs) + super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -333,7 +343,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): PATTERNS = client_patterns("/publicRooms$", v1=True) def __init__(self, hs): - super(PublicRoomListRestServlet, self).__init__(hs) + super().__init__(hs) self.hs = hs self.auth = hs.get_auth() @@ -438,13 +448,14 @@ class RoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True) def __init__(self, hs): - super(RoomMemberListRestServlet, self).__init__() + super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - requester = await self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request, allow_guest=True) handler = self.message_handler # request the state as of a given event, as identified by a stream token, @@ -455,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet): if at_token_string is None: at_token = None else: - at_token = StreamToken.from_string(at_token_string) + at_token = await StreamToken.from_string(self.store, at_token_string) # let you filter down on particular memberships. # XXX: this may not be the best shape for this API - we could pass in a filter @@ -489,7 +500,7 @@ class JoinedRoomMemberListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True) def __init__(self, hs): - super(JoinedRoomMemberListRestServlet, self).__init__() + super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() @@ -508,18 +519,23 @@ class RoomMessageListRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True) def __init__(self, hs): - super(RoomMessageListRestServlet, self).__init__() + super().__init__() self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request(request, default_limit=10) + pagination_config = await PaginationConfig.from_request( + self.store, request, default_limit=10 + ) as_client_event = b"raw" not in request.args filter_str = parse_string(request, b"filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] + event_filter = Filter( + json_decoder.decode(filter_json) + ) # type: Optional[Filter] if ( event_filter and event_filter.filter_json.get("event_format", "client") @@ -545,7 +561,7 @@ class RoomStateRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True) def __init__(self, hs): - super(RoomStateRestServlet, self).__init__() + super().__init__() self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() @@ -565,13 +581,14 @@ class RoomInitialSyncRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True) def __init__(self, hs): - super(RoomInitialSyncRestServlet, self).__init__() + super().__init__() self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() + self.store = hs.get_datastore() async def on_GET(self, request, room_id): requester = await self.auth.get_user_by_req(request, allow_guest=True) - pagination_config = PaginationConfig.from_request(request) + pagination_config = await PaginationConfig.from_request(self.store, request) content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) @@ -584,7 +601,7 @@ class RoomEventServlet(RestServlet): ) def __init__(self, hs): - super(RoomEventServlet, self).__init__() + super().__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() @@ -616,7 +633,7 @@ class RoomEventContextServlet(RestServlet): ) def __init__(self, hs): - super(RoomEventContextServlet, self).__init__() + super().__init__() self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() @@ -631,7 +648,9 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, b"filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] + event_filter = Filter( + json_decoder.decode(filter_json) + ) # type: Optional[Filter] else: event_filter = None @@ -661,7 +680,7 @@ class RoomEventContextServlet(RestServlet): class RoomForgetRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomForgetRestServlet, self).__init__(hs) + super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -687,7 +706,7 @@ class RoomForgetRestServlet(TransactionRestServlet): # TODO: Needs unit testing class RoomMembershipRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomMembershipRestServlet, self).__init__(hs) + super().__init__(hs) self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() @@ -716,17 +735,21 @@ class RoomMembershipRestServlet(TransactionRestServlet): content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): - await self.room_member_handler.do_3pid_invite( - room_id, - requester.user, - content["medium"], - content["address"], - content["id_server"], - requester, - txn_id, - new_room=False, - id_access_token=content.get("id_access_token"), - ) + try: + await self.room_member_handler.do_3pid_invite( + room_id, + requester.user, + content["medium"], + content["address"], + content["id_server"], + requester, + txn_id, + new_room=False, + id_access_token=content.get("id_access_token"), + ) + except ShadowBanError: + # Pretend the request succeeded. + pass return 200, {} target = requester.user @@ -738,15 +761,19 @@ class RoomMembershipRestServlet(TransactionRestServlet): if "reason" in content: event_content = {"reason": content["reason"]} - await self.room_member_handler.update_membership( - requester=requester, - target=target, - room_id=room_id, - action=membership_action, - txn_id=txn_id, - third_party_signed=content.get("third_party_signed", None), - content=event_content, - ) + try: + await self.room_member_handler.update_membership( + requester=requester, + target=target, + room_id=room_id, + action=membership_action, + txn_id=txn_id, + third_party_signed=content.get("third_party_signed", None), + content=event_content, + ) + except ShadowBanError: + # Pretend the request succeeded. + pass return_value = {} @@ -771,7 +798,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): class RoomRedactEventRestServlet(TransactionRestServlet): def __init__(self, hs): - super(RoomRedactEventRestServlet, self).__init__(hs) + super().__init__(hs) self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() @@ -784,20 +811,27 @@ class RoomRedactEventRestServlet(TransactionRestServlet): requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event, _ = await self.event_creation_handler.create_and_send_nonmember_event( - requester, - { - "type": EventTypes.Redaction, - "content": content, - "room_id": room_id, - "sender": requester.user.to_string(), - "redacts": event_id, - }, - txn_id=txn_id, - ) + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, + { + "type": EventTypes.Redaction, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + "redacts": event_id, + }, + txn_id=txn_id, + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) - set_tag("event_id", event.event_id) - return 200, {"event_id": event.event_id} + set_tag("event_id", event_id) + return 200, {"event_id": event_id} def on_PUT(self, request, room_id, event_id, txn_id): set_tag("txn_id", txn_id) @@ -813,7 +847,7 @@ class RoomTypingRestServlet(RestServlet): ) def __init__(self, hs): - super(RoomTypingRestServlet, self).__init__() + super().__init__() self.presence_handler = hs.get_presence_handler() self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() @@ -840,17 +874,21 @@ class RoomTypingRestServlet(RestServlet): # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) - if content["typing"]: - await self.typing_handler.started_typing( - target_user=target_user, - auth_user=requester.user, - room_id=room_id, - timeout=timeout, - ) - else: - await self.typing_handler.stopped_typing( - target_user=target_user, auth_user=requester.user, room_id=room_id - ) + try: + if content["typing"]: + await self.typing_handler.started_typing( + target_user=target_user, + requester=requester, + room_id=room_id, + timeout=timeout, + ) + else: + await self.typing_handler.stopped_typing( + target_user=target_user, requester=requester, room_id=room_id + ) + except ShadowBanError: + # Pretend this worked without error. + pass return 200, {} @@ -882,7 +920,7 @@ class SearchRestServlet(RestServlet): PATTERNS = client_patterns("/search$", v1=True) def __init__(self, hs): - super(SearchRestServlet, self).__init__() + super().__init__() self.handlers = hs.get_handlers() self.auth = hs.get_auth() @@ -903,7 +941,7 @@ class JoinedRoomsRestServlet(RestServlet): PATTERNS = client_patterns("/joined_rooms$", v1=True) def __init__(self, hs): - super(JoinedRoomsRestServlet, self).__init__() + super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 50277c6cf6..b8d491ca5c 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py
@@ -25,7 +25,7 @@ class VoipRestServlet(RestServlet): PATTERNS = client_patterns("/voip/turnServer$", v1=True) def __init__(self, hs): - super(VoipRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index d4b1ee1e8c..1320aad8f6 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py
@@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import random import re from http import HTTPStatus from typing import TYPE_CHECKING @@ -23,10 +24,13 @@ from urllib.parse import urlparse if TYPE_CHECKING: from synapse.app.homeserver import HomeServer -from twisted.internet import defer - from synapse.api.constants import LoginType -from synapse.api.errors import Codes, SynapseError, ThreepidValidationError +from synapse.api.errors import ( + Codes, + InteractiveAuthIncompleteError, + SynapseError, + ThreepidValidationError, +) from synapse.config.emailconfig import ThreepidBehaviour from synapse.http.server import finish_request, respond_with_html from synapse.http.servlet import ( @@ -35,7 +39,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import Mailer, load_jinja2_templates +from synapse.push.mailer import Mailer from synapse.types import UserID from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -50,28 +54,18 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/password/email/requestToken$") def __init__(self, hs): - super(EmailPasswordRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.datastore = hs.get_datastore() self.config = hs.config self.identity_handler = hs.get_handlers().identity_handler if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_password_reset_template_html, - self.config.email_password_reset_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_password_reset_template_html, + template_text=self.config.email_password_reset_template_text, ) async def on_POST(self, request): @@ -104,13 +98,6 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): send_attempt = body["send_attempt"] next_link = body.get("next_link") # Optional param - if not check_3pid_allowed(self.hs, "email", email): - raise SynapseError( - 403, - "Your email is not authorized on this server", - Codes.THREEPID_DENIED, - ) - if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) @@ -127,6 +114,9 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): if self.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) @@ -158,87 +148,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): return 200, ret -class PasswordResetSubmitTokenServlet(RestServlet): - """Handles 3PID validation token submission""" - - PATTERNS = client_patterns( - "/password_reset/(?P<medium>[^/]*)/submit_token$", releases=(), unstable=True - ) - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ - super(PasswordResetSubmitTokenServlet, self).__init__() - self.hs = hs - self.auth = hs.get_auth() - self.config = hs.config - self.clock = hs.get_clock() - self.store = hs.get_datastore() - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_password_reset_template_failure_html], - ) - - async def on_GET(self, request, medium): - # We currently only handle threepid token submissions for email - if medium != "email": - raise SynapseError( - 400, "This medium is currently not supported for password resets" - ) - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: - logger.warning( - "Password reset emails have been disabled due to lack of an email config" - ) - raise SynapseError( - 400, "Email-based password resets are disabled on this server" - ) - - sid = parse_string(request, "sid", required=True) - token = parse_string(request, "token", required=True) - client_secret = parse_string(request, "client_secret", required=True) - assert_valid_client_secret(client_secret) - - # Attempt to validate a 3PID session - try: - # Mark the session as valid - next_link = await self.store.validate_threepid_session( - sid, client_secret, token, self.clock.time_msec() - ) - - # Perform a 302 redirect if next_link is set - if next_link: - if next_link.startswith("file:///"): - logger.warning( - "Not redirecting to next_link as it is a local file: address" - ) - else: - request.setResponseCode(302) - request.setHeader("Location", next_link) - finish_request(request) - return None - - # Otherwise show the success template - html = self.config.email_password_reset_template_success_html - status_code = 200 - except ThreepidValidationError as e: - status_code = e.code - - # Show a failure page with a reason - template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) - - respond_with_html(request, status_code, html) - - class PasswordRestServlet(RestServlet): PATTERNS = client_patterns("/account/password$") def __init__(self, hs): - super(PasswordRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -253,18 +167,12 @@ class PasswordRestServlet(RestServlet): # we do basic sanity checks here because the auth layer will store these # in sessions. Pull out the new password provided to us. - if "new_password" in body: - new_password = body.pop("new_password") + new_password = body.pop("new_password", None) + if new_password is not None: if not isinstance(new_password, str) or len(new_password) > 512: raise SynapseError(400, "Invalid password") self.password_policy_handler.validate_password(new_password) - # If the password is valid, hash it and store it back on the body. - # This ensures that only the hashed password is handled everywhere. - if "new_password_hash" in body: - raise SynapseError(400, "Unexpected property: new_password_hash") - body["new_password_hash"] = await self.auth_handler.hash(new_password) - # there are two possibilities here. Either the user does not have an # access token, and needs to do a password reset; or they have one and # need to validate their identity. @@ -281,23 +189,52 @@ class PasswordRestServlet(RestServlet): if requester.app_service: params = body else: - params = await self.auth_handler.validate_user_via_ui_auth( - requester, + try: + ( + params, + session_id, + ) = await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + self.hs.get_ip_from_request(request), + "modify your account password", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise + user_id = requester.user.to_string() + else: + requester = None + try: + result, params, session_id = await self.auth_handler.check_ui_auth( + [[LoginType.EMAIL_IDENTITY]], request, body, self.hs.get_ip_from_request(request), "modify your account password", ) - user_id = requester.user.to_string() - else: - requester = None - result, params, _ = await self.auth_handler.check_auth( - [[LoginType.EMAIL_IDENTITY]], - request, - body, - self.hs.get_ip_from_request(request), - "modify your account password", - ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth, but + # they're not required to provide the password again. + # + # If a password is available now, hash the provided password and + # store it for later. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise if LoginType.EMAIL_IDENTITY in result: threepid = result[LoginType.EMAIL_IDENTITY] @@ -322,32 +259,40 @@ class PasswordRestServlet(RestServlet): logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - assert_params_in_dict(params, ["new_password_hash"]) - new_password_hash = params["new_password_hash"] + # If we have a password in this request, prefer it. Otherwise, there + # must be a password hash from an earlier request. + if new_password: + password_hash = await self.auth_handler.hash(new_password) + else: + password_hash = await self.auth_handler.get_session_data( + session_id, "password_hash", None + ) + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) + logout_devices = params.get("logout_devices", True) await self._set_password_handler.set_password( - user_id, new_password_hash, logout_devices, requester + user_id, password_hash, logout_devices, requester ) if self.hs.config.shadow_server: shadow_user = UserID( requester.user.localpart, self.hs.config.shadow_server.get("hs") ) - self.shadow_password(params, shadow_user.to_string()) + await self.shadow_password(params, shadow_user.to_string()) return 200, {} def on_OPTIONS(self, _): return 200, {} - @defer.inlineCallbacks - def shadow_password(self, body, user_id): + async def shadow_password(self, body, user_id): # TODO: retries shadow_hs_url = self.hs.config.shadow_server.get("hs_url") as_token = self.hs.config.shadow_server.get("as_token") - yield self.http_client.post_json_get_json( + await self.http_client.post_json_get_json( "%s/_matrix/client/r0/account/password?access_token=%s&user_id=%s" % (shadow_hs_url, as_token, user_id), body, @@ -358,7 +303,7 @@ class DeactivateAccountRestServlet(RestServlet): PATTERNS = client_patterns("/account/deactivate$") def __init__(self, hs): - super(DeactivateAccountRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -377,7 +322,7 @@ class DeactivateAccountRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request) - # allow ASes to dectivate their own users + # allow ASes to deactivate their own users if requester.app_service: await self._deactivate_account_handler.deactivate_account( requester.user.to_string(), erase @@ -406,26 +351,18 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/email/requestToken$") def __init__(self, hs): - super(EmailThreepidRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.config = hs.config self.identity_handler = hs.get_handlers().identity_handler self.store = self.hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_add_threepid_template_html, - self.config.email_add_threepid_template_text, - ], - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_add_threepid_template_html, + template_text=self.config.email_add_threepid_template_text, ) async def on_POST(self, request): @@ -473,6 +410,9 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): if self.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -509,7 +449,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): def __init__(self, hs): self.hs = hs - super(MsisdnThreepidRequestTokenRestServlet, self).__init__() + super().__init__() self.store = self.hs.get_datastore() self.identity_handler = hs.get_handlers().identity_handler @@ -545,6 +485,9 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) @@ -588,9 +531,8 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_add_threepid_template_failure_html], + self._failure_email_template = ( + self.config.email_add_threepid_template_failure_html ) async def on_GET(self, request): @@ -636,7 +578,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) @@ -687,7 +629,7 @@ class ThreepidRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid$") def __init__(self, hs): - super(ThreepidRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -731,7 +673,7 @@ class ThreepidRestServlet(RestServlet): shadow_user = UserID( requester.user.localpart, self.hs.config.shadow_server.get("hs") ) - self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) + await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) return 200, {} @@ -766,7 +708,7 @@ class ThreepidRestServlet(RestServlet): "address": validation_session["address"], "validated_at": validation_session["validated_at"], } - self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) + await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) return 200, {} @@ -774,13 +716,12 @@ class ThreepidRestServlet(RestServlet): 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED ) - @defer.inlineCallbacks - def shadow_3pid(self, body, user_id): + async def shadow_3pid(self, body, user_id): # TODO: retries shadow_hs_url = self.hs.config.shadow_server.get("hs_url") as_token = self.hs.config.shadow_server.get("as_token") - yield self.http_client.post_json_get_json( + await self.http_client.post_json_get_json( "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s" % (shadow_hs_url, as_token, user_id), body, @@ -791,7 +732,7 @@ class ThreepidAddRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/add$") def __init__(self, hs): - super(ThreepidAddRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -841,20 +782,19 @@ class ThreepidAddRestServlet(RestServlet): "address": validation_session["address"], "validated_at": validation_session["validated_at"], } - self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) + await self.shadow_3pid({"threepid": threepid}, shadow_user.to_string()) return 200, {} raise SynapseError( 400, "No validated 3pid session found", Codes.THREEPID_AUTH_FAILED ) - @defer.inlineCallbacks - def shadow_3pid(self, body, user_id): + async def shadow_3pid(self, body, user_id): # TODO: retries shadow_hs_url = self.hs.config.shadow_server.get("hs_url") as_token = self.hs.config.shadow_server.get("as_token") - yield self.http_client.post_json_get_json( + await self.http_client.post_json_get_json( "%s/_matrix/client/r0/account/3pid?access_token=%s&user_id=%s" % (shadow_hs_url, as_token, user_id), body, @@ -865,7 +805,7 @@ class ThreepidBindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/bind$") def __init__(self, hs): - super(ThreepidBindRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -894,7 +834,7 @@ class ThreepidUnbindRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/unbind$") def __init__(self, hs): - super(ThreepidUnbindRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.auth = hs.get_auth() @@ -925,7 +865,7 @@ class ThreepidDeleteRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/delete$") def __init__(self, hs): - super(ThreepidDeleteRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -958,7 +898,7 @@ class ThreepidDeleteRestServlet(RestServlet): shadow_user = UserID( requester.user.localpart, self.hs.config.shadow_server.get("hs") ) - self.shadow_3pid_delete(body, shadow_user.to_string()) + await self.shadow_3pid_delete(body, shadow_user.to_string()) if ret: id_server_unbind_result = "success" @@ -967,13 +907,12 @@ class ThreepidDeleteRestServlet(RestServlet): return 200, {"id_server_unbind_result": id_server_unbind_result} - @defer.inlineCallbacks - def shadow_3pid_delete(self, body, user_id): + async def shadow_3pid_delete(self, body, user_id): # TODO: retries shadow_hs_url = self.hs.config.shadow_server.get("hs_url") as_token = self.hs.config.shadow_server.get("as_token") - yield self.http_client.post_json_get_json( + await self.http_client.post_json_get_json( "%s/_matrix/client/r0/account/3pid/delete?access_token=%s&user_id=%s" % (shadow_hs_url, as_token, user_id), body, @@ -988,12 +927,11 @@ class ThreepidLookupRestServlet(RestServlet): self.auth = hs.get_auth() self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): """Proxy a /_matrix/identity/api/v1/lookup request to an identity server """ - yield self.auth.get_user_by_req(request) + await self.auth.get_user_by_req(request) # Verify query parameters query_params = request.args @@ -1006,9 +944,9 @@ class ThreepidLookupRestServlet(RestServlet): # Proxy the request to the identity server. lookup_3pid handles checking # if the lookup is allowed so we don't need to do it here. - ret = yield self.identity_handler.proxy_lookup_3pid(id_server, medium, address) + ret = await self.identity_handler.proxy_lookup_3pid(id_server, medium, address) - defer.returnValue((200, ret)) + return 200, ret class ThreepidBulkLookupRestServlet(RestServlet): @@ -1019,12 +957,11 @@ class ThreepidBulkLookupRestServlet(RestServlet): self.auth = hs.get_auth() self.identity_handler = hs.get_handlers().identity_handler - @defer.inlineCallbacks - def on_POST(self, request): + async def on_POST(self, request): """Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity server """ - yield self.auth.get_user_by_req(request) + await self.auth.get_user_by_req(request) body = parse_json_object_from_request(request) @@ -1032,11 +969,11 @@ class ThreepidBulkLookupRestServlet(RestServlet): # Proxy the request to the identity server. lookup_3pid handles checking # if the lookup is allowed so we don't need to do it here. - ret = yield self.identity_handler.proxy_bulk_lookup_3pid( + ret = await self.identity_handler.proxy_bulk_lookup_3pid( body["id_server"], body["threepids"] ) - defer.returnValue((200, ret)) + return 200, ret def assert_valid_next_link(hs: "HomeServer", next_link: str): @@ -1082,7 +1019,7 @@ class WhoamiRestServlet(RestServlet): PATTERNS = client_patterns("/account/whoami$") def __init__(self, hs): - super(WhoamiRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() async def on_GET(self, request): @@ -1093,7 +1030,6 @@ class WhoamiRestServlet(RestServlet): def register_servlets(hs, http_server): EmailPasswordRequestTokenRestServlet(hs).register(http_server) - PasswordResetSubmitTokenServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) EmailThreepidRequestTokenRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index d31ec7c29d..617ee6d62a 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -35,7 +35,7 @@ class AccountDataServlet(RestServlet): ) def __init__(self, hs): - super(AccountDataServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() @@ -93,7 +93,7 @@ class RoomAccountDataServlet(RestServlet): ) def __init__(self, hs): - super(RoomAccountDataServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py
index d06336ceea..bd7f9ae203 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py
@@ -32,7 +32,7 @@ class AccountValidityRenewServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(AccountValidityRenewServlet, self).__init__() + super().__init__() self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() @@ -67,7 +67,7 @@ class AccountValiditySendMailServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(AccountValiditySendMailServlet, self).__init__() + super().__init__() self.hs = hs self.account_activity_handler = hs.get_account_validity_handler() diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 8e585e9153..5fbfae5991 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py
@@ -25,94 +25,6 @@ from ._base import client_patterns logger = logging.getLogger(__name__) -RECAPTCHA_TEMPLATE = """ -<html> -<head> -<title>Authentication</title> -<meta name='viewport' content='width=device-width, initial-scale=1, - user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> -<script src="https://www.recaptcha.net/recaptcha/api.js" - async defer></script> -<script src="//code.jquery.com/jquery-1.11.2.min.js"></script> -<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> -<script> -function captchaDone() { - $('#registrationForm').submit(); -} -</script> -</head> -<body> -<form id="registrationForm" method="post" action="%(myurl)s"> - <div> - <p> - Hello! We need to prevent computer programs and other automated - things from creating accounts on this server. - </p> - <p> - Please verify that you're not a robot. - </p> - <input type="hidden" name="session" value="%(session)s" /> - <div class="g-recaptcha" - data-sitekey="%(sitekey)s" - data-callback="captchaDone"> - </div> - <noscript> - <input type="submit" value="All Done" /> - </noscript> - </div> - </div> -</form> -</body> -</html> -""" - -TERMS_TEMPLATE = """ -<html> -<head> -<title>Authentication</title> -<meta name='viewport' content='width=device-width, initial-scale=1, - user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> -<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> -</head> -<body> -<form id="registrationForm" method="post" action="%(myurl)s"> - <div> - <p> - Please click the button below if you agree to the - <a href="%(terms_url)s">privacy policy of this homeserver.</a> - </p> - <input type="hidden" name="session" value="%(session)s" /> - <input type="submit" value="Agree" /> - </div> -</form> -</body> -</html> -""" - -SUCCESS_TEMPLATE = """ -<html> -<head> -<title>Success!</title> -<meta name='viewport' content='width=device-width, initial-scale=1, - user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'> -<link rel="stylesheet" href="/_matrix/static/client/register/style.css"> -<script> -if (window.onAuthDone) { - window.onAuthDone(); -} else if (window.opener && window.opener.postMessage) { - window.opener.postMessage("authDone", "*"); -} -</script> -</head> -<body> - <div> - <p>Thank you</p> - <p>You may now close this window and return to the application</p> - </div> -</body> -</html> -""" - class AuthRestServlet(RestServlet): """ @@ -124,7 +36,7 @@ class AuthRestServlet(RestServlet): PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") def __init__(self, hs): - super(AuthRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() @@ -145,26 +57,30 @@ class AuthRestServlet(RestServlet): self._cas_server_url = hs.config.cas_server_url self._cas_service_url = hs.config.cas_service_url + self.recaptcha_template = hs.config.recaptcha_template + self.terms_template = hs.config.terms_template + self.success_template = hs.config.fallback_success_template + async def on_GET(self, request, stagetype): session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") if stagetype == LoginType.RECAPTCHA: - html = RECAPTCHA_TEMPLATE % { - "session": session, - "myurl": "%s/r0/auth/%s/fallback/web" + html = self.recaptcha_template.render( + session=session, + myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), - "sitekey": self.hs.config.recaptcha_public_key, - } + sitekey=self.hs.config.recaptcha_public_key, + ) elif stagetype == LoginType.TERMS: - html = TERMS_TEMPLATE % { - "session": session, - "terms_url": "%s_matrix/consent?v=%s" + html = self.terms_template.render( + session=session, + terms_url="%s_matrix/consent?v=%s" % (self.hs.config.public_baseurl, self.hs.config.user_consent_version), - "myurl": "%s/r0/auth/%s/fallback/web" + myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), - } + ) elif stagetype == LoginType.SSO: # Display a confirmation page which prompts the user to @@ -222,14 +138,14 @@ class AuthRestServlet(RestServlet): ) if success: - html = SUCCESS_TEMPLATE + html = self.success_template.render() else: - html = RECAPTCHA_TEMPLATE % { - "session": session, - "myurl": "%s/r0/auth/%s/fallback/web" + html = self.recaptcha_template.render( + session=session, + myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), - "sitekey": self.hs.config.recaptcha_public_key, - } + sitekey=self.hs.config.recaptcha_public_key, + ) elif stagetype == LoginType.TERMS: authdict = {"session": session} @@ -238,18 +154,18 @@ class AuthRestServlet(RestServlet): ) if success: - html = SUCCESS_TEMPLATE + html = self.success_template.render() else: - html = TERMS_TEMPLATE % { - "session": session, - "terms_url": "%s_matrix/consent?v=%s" + html = self.terms_template.render( + session=session, + terms_url="%s_matrix/consent?v=%s" % ( self.hs.config.public_baseurl, self.hs.config.user_consent_version, ), - "myurl": "%s/r0/auth/%s/fallback/web" + myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), - } + ) elif stagetype == LoginType.SSO: # The SSO fallback workflow should not post here, raise SynapseError(404, "Fallback SSO auth does not support POST requests.") diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py
index fe9d019c44..76879ac559 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py
@@ -32,7 +32,7 @@ class CapabilitiesRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(CapabilitiesRestServlet, self).__init__() + super().__init__() self.hs = hs self.config = hs.config self.auth = hs.get_auth() diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py
index c0714fcfb1..7e174de692 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py
@@ -35,7 +35,7 @@ class DevicesRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(DevicesRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() @@ -57,7 +57,7 @@ class DeleteDevicesRestServlet(RestServlet): PATTERNS = client_patterns("/delete_devices") def __init__(self, hs): - super(DeleteDevicesRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() @@ -102,7 +102,7 @@ class DeviceRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(DeviceRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py
index b28da017cd..7cc692643b 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py
@@ -28,7 +28,7 @@ class GetFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") def __init__(self, hs): - super(GetFilterRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() @@ -64,7 +64,7 @@ class CreateFilterRestServlet(RestServlet): PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter") def __init__(self, hs): - super(CreateFilterRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.filtering = hs.get_filtering() diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py
index d84a6d7e11..a3bb095c2d 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py
@@ -16,6 +16,7 @@ import logging +from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import GroupID @@ -31,7 +32,7 @@ class GroupServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$") def __init__(self, hs): - super(GroupServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -65,7 +66,7 @@ class GroupSummaryServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$") def __init__(self, hs): - super(GroupSummaryServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -96,7 +97,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): ) def __init__(self, hs): - super(GroupSummaryRoomsCatServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -136,7 +137,7 @@ class GroupCategoryServlet(RestServlet): ) def __init__(self, hs): - super(GroupCategoryServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -180,7 +181,7 @@ class GroupCategoriesServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/categories/$") def __init__(self, hs): - super(GroupCategoriesServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -203,7 +204,7 @@ class GroupRoleServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$") def __init__(self, hs): - super(GroupRoleServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -247,7 +248,7 @@ class GroupRolesServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/roles/$") def __init__(self, hs): - super(GroupRolesServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -278,7 +279,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): ) def __init__(self, hs): - super(GroupSummaryUsersRoleServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -316,7 +317,7 @@ class GroupRoomServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$") def __init__(self, hs): - super(GroupRoomServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -325,6 +326,9 @@ class GroupRoomServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=True) requester_user_id = requester.user.to_string() + if not GroupID.is_valid(group_id): + raise SynapseError(400, "%s was not legal group ID" % (group_id,)) + result = await self.groups_handler.get_rooms_in_group( group_id, requester_user_id ) @@ -339,7 +343,7 @@ class GroupUsersServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$") def __init__(self, hs): - super(GroupUsersServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -362,7 +366,7 @@ class GroupInvitedUsersServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") def __init__(self, hs): - super(GroupInvitedUsersServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -385,7 +389,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") def __init__(self, hs): - super(GroupSettingJoinPolicyServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.groups_handler = hs.get_groups_local_handler() @@ -409,7 +413,7 @@ class GroupCreateServlet(RestServlet): PATTERNS = client_patterns("/create_group$") def __init__(self, hs): - super(GroupCreateServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -440,7 +444,7 @@ class GroupAdminRoomsServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminRoomsServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -477,7 +481,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminRoomsConfigServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -503,7 +507,7 @@ class GroupAdminUsersInviteServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminUsersInviteServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -532,7 +536,7 @@ class GroupAdminUsersKickServlet(RestServlet): ) def __init__(self, hs): - super(GroupAdminUsersKickServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -556,7 +560,7 @@ class GroupSelfLeaveServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/leave$") def __init__(self, hs): - super(GroupSelfLeaveServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -580,7 +584,7 @@ class GroupSelfJoinServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/join$") def __init__(self, hs): - super(GroupSelfJoinServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -604,7 +608,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/accept_invite$") def __init__(self, hs): - super(GroupSelfAcceptInviteServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() @@ -628,7 +632,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/self/update_publicity$") def __init__(self, hs): - super(GroupSelfUpdatePublicityServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -651,7 +655,7 @@ class PublicisedGroupsForUserServlet(RestServlet): PATTERNS = client_patterns("/publicised_groups/(?P<user_id>[^/]*)$") def __init__(self, hs): - super(PublicisedGroupsForUserServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -672,7 +676,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): PATTERNS = client_patterns("/publicised_groups$") def __init__(self, hs): - super(PublicisedGroupsForUsersServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.store = hs.get_datastore() @@ -696,7 +700,7 @@ class GroupsForUserServlet(RestServlet): PATTERNS = client_patterns("/joined_groups$") def __init__(self, hs): - super(GroupsForUserServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.clock = hs.get_clock() self.groups_handler = hs.get_groups_local_handler() diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index 24bb090822..55c4606569 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py
@@ -64,7 +64,7 @@ class KeyUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(KeyUploadServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -147,7 +147,7 @@ class KeyQueryServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ - super(KeyQueryServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -177,9 +177,10 @@ class KeyChangesServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ - super(KeyChangesServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.device_handler = hs.get_device_handler() + self.store = hs.get_datastore() async def on_GET(self, request): requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -191,7 +192,7 @@ class KeyChangesServlet(RestServlet): # changes after the "to" as well as before. set_tag("to", parse_string(request, "to")) - from_token = StreamToken.from_string(from_token_string) + from_token = await StreamToken.from_string(self.store, from_token_string) user_id = requester.user.to_string() @@ -222,7 +223,7 @@ class OneTimeKeyServlet(RestServlet): PATTERNS = client_patterns("/keys/claim$") def __init__(self, hs): - super(OneTimeKeyServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -250,7 +251,7 @@ class SigningKeyUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(SigningKeyUploadServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() @@ -308,7 +309,7 @@ class SignaturesUploadServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(SignaturesUploadServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_keys_handler = hs.get_e2e_keys_handler() diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py
index aa911d75ee..87063ec8b1 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py
@@ -27,7 +27,7 @@ class NotificationsServlet(RestServlet): PATTERNS = client_patterns("/notifications$") def __init__(self, hs): - super(NotificationsServlet, self).__init__() + super().__init__() self.store = hs.get_datastore() self.auth = hs.get_auth() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py
index 6ae9a5a8e9..5b996e2d63 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py
@@ -60,7 +60,7 @@ class IdTokenServlet(RestServlet): EXPIRES_MS = 3600 * 1000 def __init__(self, hs): - super(IdTokenServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/password_policy.py b/synapse/rest/client/v2_alpha/password_policy.py
index 968403cca4..68b27ff23a 100644 --- a/synapse/rest/client/v2_alpha/password_policy.py +++ b/synapse/rest/client/v2_alpha/password_policy.py
@@ -30,7 +30,7 @@ class PasswordPolicyServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(PasswordPolicyServlet, self).__init__() + super().__init__() self.policy = hs.config.password_policy self.enabled = hs.config.password_policy_enabled diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py
index 67cbc37312..55c6688f52 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py
@@ -26,7 +26,7 @@ class ReadMarkerRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$") def __init__(self, hs): - super(ReadMarkerRestServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() self.read_marker_handler = hs.get_read_marker_handler() diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
index 92555bd4a9..6f7246a394 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -31,7 +31,7 @@ class ReceiptRestServlet(RestServlet): ) def __init__(self, hs): - super(ReceiptRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 001f49fb3e..91ea76bc20 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py
@@ -17,6 +17,7 @@ import hmac import logging +import random import re from typing import List, Union @@ -26,6 +27,7 @@ import synapse.types from synapse.api.constants import LoginType from synapse.api.errors import ( Codes, + InteractiveAuthIncompleteError, SynapseError, ThreepidValidationError, UnrecognizedRequestError, @@ -45,7 +47,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import load_jinja2_templates +from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -76,29 +78,17 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(EmailRegisterRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler self.config = hs.config if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - from synapse.push.mailer import Mailer, load_jinja2_templates - - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_registration_template_html, - self.config.email_registration_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_registration_template_html, + template_text=self.config.email_registration_template_text, ) async def on_POST(self, request): @@ -144,6 +134,9 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) @@ -183,7 +176,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(MsisdnRegisterRequestTokenRestServlet, self).__init__() + super().__init__() self.hs = hs self.identity_handler = hs.get_handlers().identity_handler @@ -218,6 +211,9 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): if self.hs.config.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. + # Also wait for some random amount of time between 100ms and 1s to make it + # look like we did something. + await self.hs.clock.sleep(random.randint(1, 10) / 10) return 200, {"sid": random_string(16)} raise SynapseError( @@ -257,7 +253,7 @@ class RegistrationSubmitTokenServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RegistrationSubmitTokenServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.config = hs.config @@ -265,15 +261,8 @@ class RegistrationSubmitTokenServlet(RestServlet): self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], - ) - - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], + self._failure_email_template = ( + self.config.email_registration_template_failure_html ) async def on_GET(self, request, medium): @@ -321,7 +310,7 @@ class RegistrationSubmitTokenServlet(RestServlet): # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) @@ -334,7 +323,7 @@ class UsernameAvailabilityRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(UsernameAvailabilityRestServlet, self).__init__() + super().__init__() self.hs = hs self.registration_handler = hs.get_registration_handler() self.ratelimiter = FederationRateLimiter( @@ -372,7 +361,7 @@ class RegisterRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RegisterRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -385,6 +374,7 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() + self._registration_enabled = self.hs.config.enable_registration self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -410,22 +400,6 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind.decode("utf8"),) ) - # we do basic sanity checks here because the auth layer will store these - # in sessions. Pull out the username/password provided to us. - desired_password_hash = None - if "password" in body: - password = body.pop("password") - if not isinstance(password, str) or len(password) > 512: - raise SynapseError(400, "Invalid password") - self.password_policy_handler.validate_password(password) - - # If the password is valid, hash it and store it back on the body. - # This ensures that only the hashed password is handled everywhere. - if "password_hash" in body: - raise SynapseError(400, "Unexpected property: password_hash") - body["password_hash"] = await self.auth_handler.hash(password) - desired_password_hash = body["password_hash"] - # We don't care about usernames for this deployment. In fact, the act # of checking whether they exist already can leak metadata about # which users are already registered. @@ -440,7 +414,12 @@ class RegisterRestServlet(RestServlet): appservice = None if self.auth.has_access_token(request): - appservice = await self.auth.get_appservice_by_req(request) + appservice = self.auth.get_appservice_by_req(request) + + # We need to retrieve the password early in order to pass it to + # application service registration + # This is specific to shadow server registration of users via an AS + password = body.pop("password", None) # fork off as soon as possible for ASes which have completely # different registration flows to normal users @@ -459,23 +438,33 @@ class RegisterRestServlet(RestServlet): access_token = self.auth.get_access_token_from_request(request) - if isinstance(desired_username, str): - result = await self._do_appservice_registration( - desired_username, - desired_password_hash, - desired_display_name, - access_token, - body, - ) - return 200, result # we throw for non 200 responses + if not isinstance(desired_username, str): + raise SynapseError(400, "Desired Username is missing or not a string") + + result = await self._do_appservice_registration( + desired_username, password, desired_display_name, access_token, body + ) + + return 200, result # == Normal User Registration == (everyone else) - if not self.hs.config.enable_registration: + if not self._registration_enabled: raise SynapseError(403, "Registration has been disabled") + # Check if this account is upgrading from a guest account. guest_access_token = body.get("guest_access_token", None) - if "initial_device_display_name" in body and "password_hash" not in body: + # Pull out the provided password and do basic sanity checks early. + # + # Note that we remove the password from the body since the auth layer + # will store the body in the session and we don't want a plaintext + # password store there. + if password is not None: + if not isinstance(password, str) or len(password) > 512: + raise SynapseError(400, "Invalid password") + self.password_policy_handler.validate_password(password) + + if "initial_device_display_name" in body and password is None: # ignore 'initial_device_display_name' if sent without # a password to work around a client bug where it sent # the 'initial_device_display_name' param alone, wiping out @@ -485,6 +474,7 @@ class RegisterRestServlet(RestServlet): session_id = self.auth_handler.get_session_id(body) registered_user_id = None + password_hash = None if session_id: # if we get a registered user id out of here, it means we previously # registered a user for this session, so we could just return the @@ -493,21 +483,43 @@ class RegisterRestServlet(RestServlet): registered_user_id = await self.auth_handler.get_session_data( session_id, "registered_user_id", None ) + # Extract the previously-hashed password from the session. + password_hash = await self.auth_handler.get_session_data( + session_id, "password_hash", None + ) - auth_result, params, session_id = await self.auth_handler.check_auth( - self._registration_flows, - request, - body, - self.hs.get_ip_from_request(request), - "register a new account", - ) + # Check if the user-interactive authentication flows are complete, if + # not this will raise a user-interactive auth error. + try: + auth_result, params, session_id = await self.auth_handler.check_ui_auth( + self._registration_flows, + request, + body, + self.hs.get_ip_from_request(request), + "register a new account", + ) + except InteractiveAuthIncompleteError as e: + # The user needs to provide more steps to complete auth. + # + # Hash the password and store it with the session since the client + # is not required to provide the password again. + # + # If a password hash was previously stored we will not attempt to + # re-hash and store it for efficiency. This assumes the password + # does not change throughout the authentication flow, but this + # should be fine since the data is meant to be consistent. + if not password_hash and password: + password_hash = await self.auth_handler.hash(password) + await self.auth_handler.set_session_data( + e.session_id, "password_hash", password_hash + ) + raise # Check that we're not trying to register a denied 3pid. # # the user-facing checks will probably already have happened in # /register/email/requestToken when we requested a 3pid, but that's not # guaranteed. - if auth_result: for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]: if login_type in auth_result: @@ -603,8 +615,12 @@ class RegisterRestServlet(RestServlet): # don't re-register the threepids registered = False else: - # NB: This may be from the auth handler and NOT from the POST - assert_params_in_dict(params, ["password_hash"]) + # If we have a password in this request, prefer it. Otherwise, there + # might be a password hash from an earlier request. + if password: + password_hash = await self.auth_handler.hash(password) + if not password_hash: + raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) if not self.hs.config.register_mxid_from_3pid: desired_username = params.get("username", None) @@ -613,7 +629,6 @@ class RegisterRestServlet(RestServlet): pass guest_access_token = params.get("guest_access_token", None) - new_password_hash = params.get("password_hash", None) if desired_username is not None: desired_username = desired_username.lower() @@ -653,13 +668,18 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) + entries = await self.store.get_user_agents_ips_to_ui_auth_session( + session_id + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, - password_hash=new_password_hash, + password_hash=password_hash, guest_access_token=guest_access_token, default_display_name=desired_display_name, threepid=threepid, address=client_addr, + user_agent_ips=entries, ) # Necessary due to auth checks prior to the threepid being # written to the db @@ -677,8 +697,8 @@ class RegisterRestServlet(RestServlet): params=params, ) - # remember that we've now registered that user account, and with - # what user ID (since the user may not have specified) + # Remember that the user account has been registered (and the user + # ID it was registered with, since it might not have been specified). await self.auth_handler.set_session_data( session_id, "registered_user_id", registered_user_id ) @@ -702,12 +722,20 @@ class RegisterRestServlet(RestServlet): return 200, {} async def _do_appservice_registration( - self, username, password_hash, display_name, as_token, body + self, username, password, display_name, as_token, body ): # FIXME: appservice_register() is horribly duplicated with register() # and they should probably just be combined together with a config flag. + + if password: + # Hash the password + # + # In mainline hashing of the password was moved further on in the registration + # flow, but we need it here for the AS use-case of shadow servers + password = await self.auth_handler.hash(password) + user_id = await self.registration_handler.appservice_register( - username, as_token, password_hash, display_name + username, as_token, password, display_name ) result = await self._create_registration_details(user_id, body) @@ -736,7 +764,7 @@ class RegisterRestServlet(RestServlet): (object) params: registration parameters, from which we pull device_id, initial_device_name and inhibit_login Returns: - defer.Deferred: (object) dictionary for response from /register + dictionary for response from /register """ result = {"user_id": user_id, "home_server": self.hs.hostname} if not params.get("inhibit_login", False): diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py
index 89002ffbff..18c75738f8 100644 --- a/synapse/rest/client/v2_alpha/relations.py +++ b/synapse/rest/client/v2_alpha/relations.py
@@ -22,7 +22,7 @@ any time to reflect changes in the MSC. import logging from synapse.api.constants import EventTypes, RelationTypes -from synapse.api.errors import SynapseError +from synapse.api.errors import ShadowBanError, SynapseError from synapse.http.servlet import ( RestServlet, parse_integer, @@ -35,6 +35,7 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) +from synapse.util.stringutils import random_string from ._base import client_patterns @@ -60,7 +61,7 @@ class RelationSendServlet(RestServlet): ) def __init__(self, hs): - super(RelationSendServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.event_creation_handler = hs.get_event_creation_handler() self.txns = HttpTransactionCache(hs) @@ -111,11 +112,18 @@ class RelationSendServlet(RestServlet): "sender": requester.user.to_string(), } - event, _ = await self.event_creation_handler.create_and_send_nonmember_event( - requester, event_dict=event_dict, txn_id=txn_id - ) + try: + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict=event_dict, txn_id=txn_id + ) + event_id = event.event_id + except ShadowBanError: + event_id = "$" + random_string(43) - return 200, {"event_id": event.event_id} + return 200, {"event_id": event_id} class RelationPaginationServlet(RestServlet): @@ -130,7 +138,7 @@ class RelationPaginationServlet(RestServlet): ) def __init__(self, hs): - super(RelationPaginationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() @@ -225,7 +233,7 @@ class RelationAggregationPaginationServlet(RestServlet): ) def __init__(self, hs): - super(RelationAggregationPaginationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.event_handler = hs.get_event_handler() @@ -303,7 +311,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): ) def __init__(self, hs): - super(RelationAggregationGroupPaginationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py
index e15927c4ea..215d619ca1 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py
@@ -32,7 +32,7 @@ class ReportEventRestServlet(RestServlet): PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$") def __init__(self, hs): - super(ReportEventRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.clock = hs.get_clock() diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py
index 59529707df..53de97923f 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py
@@ -37,7 +37,7 @@ class RoomKeysServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RoomKeysServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() @@ -248,7 +248,7 @@ class RoomKeysNewVersionServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RoomKeysNewVersionServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() @@ -301,7 +301,7 @@ class RoomKeysVersionServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(RoomKeysVersionServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler() diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
index f357015a70..bf030e0ff4 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py
@@ -15,13 +15,14 @@ import logging -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, ShadowBanError, SynapseError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) +from synapse.util import stringutils from ._base import client_patterns @@ -52,7 +53,7 @@ class RoomUpgradeRestServlet(RestServlet): ) def __init__(self, hs): - super(RoomUpgradeRestServlet, self).__init__() + super().__init__() self._hs = hs self._room_creation_handler = hs.get_room_creation_handler() self._auth = hs.get_auth() @@ -62,7 +63,6 @@ class RoomUpgradeRestServlet(RestServlet): content = parse_json_object_from_request(request) assert_params_in_dict(content, ("new_version",)) - new_version = content["new_version"] new_version = KNOWN_ROOM_VERSIONS.get(content["new_version"]) if new_version is None: @@ -72,9 +72,13 @@ class RoomUpgradeRestServlet(RestServlet): Codes.UNSUPPORTED_ROOM_VERSION, ) - new_room_id = await self._room_creation_handler.upgrade_room( - requester, room_id, new_version - ) + try: + new_room_id = await self._room_creation_handler.upgrade_room( + requester, room_id, new_version + ) + except ShadowBanError: + # Generate a random room ID. + new_room_id = stringutils.random_string(18) ret = {"replacement_room": new_room_id} diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py
index db829f3098..bc4f43639a 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py
@@ -36,7 +36,7 @@ class SendToDeviceRestServlet(servlet.RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(SendToDeviceRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs) diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py new file mode 100644
index 0000000000..c866d5151c --- /dev/null +++ b/synapse/rest/client/v2_alpha/shared_rooms.py
@@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Half-Shot +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.api.errors import Codes, SynapseError +from synapse.http.servlet import RestServlet +from synapse.types import UserID + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class UserSharedRoomsServlet(RestServlet): + """ + GET /uk.half-shot.msc2666/user/shared_rooms/{user_id} HTTP/1.1 + """ + + PATTERNS = client_patterns( + "/uk.half-shot.msc2666/user/shared_rooms/(?P<user_id>[^/]*)", + releases=(), # This is an unstable feature + ) + + def __init__(self, hs): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.user_directory_active = hs.config.update_user_directory + + async def on_GET(self, request, user_id): + + if not self.user_directory_active: + raise SynapseError( + code=400, + msg="The user directory is disabled on this server. Cannot determine shared rooms.", + errcode=Codes.FORBIDDEN, + ) + + UserID.from_string(user_id) + + requester = await self.auth.get_user_by_req(request) + if user_id == requester.user.to_string(): + raise SynapseError( + code=400, + msg="You cannot request a list of shared rooms with yourself", + errcode=Codes.FORBIDDEN, + ) + rooms = await self.store.get_shared_rooms_for_users( + requester.user.to_string(), user_id + ) + + return 200, {"joined": list(rooms)} + + +def register_servlets(hs, http_server): + UserSharedRoomsServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index a5c24fbd63..6779df952f 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py
@@ -16,8 +16,6 @@ import itertools import logging -from canonicaljson import json - from synapse.api.constants import PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection @@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.handlers.sync import SyncConfig from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.types import StreamToken +from synapse.util import json_decoder from ._base import client_patterns, set_timeline_upper_limit @@ -75,9 +74,10 @@ class SyncRestServlet(RestServlet): ALLOWED_PRESENCE = {"online", "offline", "unavailable"} def __init__(self, hs): - super(SyncRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() + self.store = hs.get_datastore() self.sync_handler = hs.get_sync_handler() self.clock = hs.get_clock() self.filtering = hs.get_filtering() @@ -125,7 +125,7 @@ class SyncRestServlet(RestServlet): filter_collection = DEFAULT_FILTER_COLLECTION elif filter_id.startswith("{"): try: - filter_object = json.loads(filter_id) + filter_object = json_decoder.decode(filter_id) set_timeline_upper_limit( filter_object, self.hs.config.filter_timeline_limit ) @@ -152,10 +152,9 @@ class SyncRestServlet(RestServlet): device_id=device_id, ) + since_token = None if since is not None: - since_token = StreamToken.from_string(since) - else: - since_token = None + since_token = await StreamToken.from_string(self.store, since) # send any outstanding server notices to the user. await self._server_notices_sender.on_user_syncing(user.to_string()) @@ -237,7 +236,7 @@ class SyncRestServlet(RestServlet): "leave": sync_result.groups.leave, }, "device_one_time_keys_count": sync_result.device_one_time_keys_count, - "next_batch": sync_result.next_batch.to_string(), + "next_batch": await sync_result.next_batch.to_string(self.store), } @staticmethod @@ -414,7 +413,7 @@ class SyncRestServlet(RestServlet): result = { "timeline": { "events": serialized_timeline, - "prev_batch": room.timeline.prev_batch.to_string(), + "prev_batch": await room.timeline.prev_batch.to_string(self.store), "limited": room.timeline.limited, }, "state": {"events": serialized_state}, @@ -426,6 +425,7 @@ class SyncRestServlet(RestServlet): result["ephemeral"] = {"events": ephemeral_events} result["unread_notifications"] = room.unread_notifications result["summary"] = room.summary + result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py
index a3f12e8a77..bf3a79db44 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py
@@ -31,7 +31,7 @@ class TagListServlet(RestServlet): PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags") def __init__(self, hs): - super(TagListServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -56,7 +56,7 @@ class TagServlet(RestServlet): ) def __init__(self, hs): - super(TagServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py
index 23709960ad..0c127a1b5f 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py
@@ -28,7 +28,7 @@ class ThirdPartyProtocolsServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/protocols") def __init__(self, hs): - super(ThirdPartyProtocolsServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @@ -44,7 +44,7 @@ class ThirdPartyProtocolServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$") def __init__(self, hs): - super(ThirdPartyProtocolServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @@ -65,7 +65,7 @@ class ThirdPartyUserServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$") def __init__(self, hs): - super(ThirdPartyUserServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() @@ -87,7 +87,7 @@ class ThirdPartyLocationServlet(RestServlet): PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$") def __init__(self, hs): - super(ThirdPartyLocationServlet, self).__init__() + super().__init__() self.auth = hs.get_auth() self.appservice_handler = hs.get_application_service_handler() diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py
index 83f3b6b70a..79317c74ba 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py
@@ -28,7 +28,7 @@ class TokenRefreshRestServlet(RestServlet): PATTERNS = client_patterns("/tokenrefresh") def __init__(self, hs): - super(TokenRefreshRestServlet, self).__init__() + super().__init__() async def on_POST(self, request): raise AuthError(403, "tokenrefresh is no longer supported.") diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py
index 6e8300d6a5..5d4be8adaf 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py
@@ -39,7 +39,7 @@ class UserDirectorySearchRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): server """ - super(UserDirectorySearchRestServlet, self).__init__() + super().__init__() self.hs = hs self.auth = hs.get_auth() self.user_directory_handler = hs.get_user_directory_handler() diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index b1999d051b..c9b9e7f5ff 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py
@@ -19,6 +19,7 @@ import logging import re +from synapse.api.constants import RoomCreationPreset from synapse.http.servlet import RestServlet logger = logging.getLogger(__name__) @@ -28,9 +29,23 @@ class VersionsRestServlet(RestServlet): PATTERNS = [re.compile("^/_matrix/client/versions$")] def __init__(self, hs): - super(VersionsRestServlet, self).__init__() + super().__init__() self.config = hs.config + # Calculate these once since they shouldn't change after start-up. + self.e2ee_forced_public = ( + RoomCreationPreset.PUBLIC_CHAT + in self.config.encryption_enabled_by_default_for_room_presets + ) + self.e2ee_forced_private = ( + RoomCreationPreset.PRIVATE_CHAT + in self.config.encryption_enabled_by_default_for_room_presets + ) + self.e2ee_forced_trusted_private = ( + RoomCreationPreset.TRUSTED_PRIVATE_CHAT + in self.config.encryption_enabled_by_default_for_room_presets + ) + def on_GET(self, request): return ( 200, @@ -63,6 +78,12 @@ class VersionsRestServlet(RestServlet): # Tchap does not currently assume this rule for r0.5.0 # XXX: Remove this when it does "m.lazy_load_members": True, + # Implements additional endpoints as described in MSC2666 + "uk.half-shot.msc2666": True, + # Whether new rooms will be set to encrypted or not (based on presets). + "io.element.e2ee_forced.public": self.e2ee_forced_public, + "io.element.e2ee_forced.private": self.e2ee_forced_private, + "io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private, }, }, ) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py
index 4386eb4e72..b3e4d5612e 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py
@@ -22,8 +22,6 @@ from os import path import jinja2 from jinja2 import TemplateNotFound -from twisted.internet import defer - from synapse.api.errors import NotFoundError, StoreError, SynapseError from synapse.config import ConfigError from synapse.http.server import DirectServeHtmlResource, respond_with_html @@ -135,7 +133,7 @@ class ConsentResource(DirectServeHtmlResource): else: qualified_user_id = UserID(username, self.hs.hostname).to_string() - u = await defer.maybeDeferred(self.store.get_user_by_id, qualified_user_id) + u = await self.store.get_user_by_id(qualified_user_id) if u is None: raise NotFoundError("Unknown user") diff --git a/synapse/rest/health.py b/synapse/rest/health.py new file mode 100644
index 0000000000..0170950bf3 --- /dev/null +++ b/synapse/rest/health.py
@@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.web.resource import Resource + + +class HealthResource(Resource): + """A resource that does nothing except return a 200 with a body of `OK`, + which can be used as a health check. + + Note: `SynapseRequest._should_log_request` ensures that requests to + `/health` do not get logged at INFO. + """ + + isLeaf = 1 + + def render_GET(self, request): + request.setHeader(b"Content-Type", b"text/plain") + return b"OK" diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 9b3f85b306..f843f02454 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -15,19 +15,19 @@ import logging from typing import Dict, Set -from canonicaljson import encode_canonical_json, json from signedjson.sign import sign_json from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyring import ServerKeyFetcher -from synapse.http.server import DirectServeJsonResource, respond_with_json_bytes +from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_integer, parse_json_object_from_request +from synapse.util import json_decoder logger = logging.getLogger(__name__) class RemoteKey(DirectServeJsonResource): - """HTTP resource for retreiving the TLS certificate and NACL signature + """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported X.509 TLS certificate matches the one used in the HTTPS connection. Checks that the NACL signature for the remote server is valid. Returns a dict of @@ -35,7 +35,7 @@ class RemoteKey(DirectServeJsonResource): Supports individual GET APIs and a bulk query POST API. - Requsts: + Requests: GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1 @@ -209,13 +209,15 @@ class RemoteKey(DirectServeJsonResource): # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(result["key_json"])) + # If there is a cache miss, request the missing keys, then recurse (and + # ensure the result is sent). if cache_misses and query_remote_on_cache_miss: await self.fetcher.get_keys(cache_misses) await self.query_keys(request, query, query_remote_on_cache_miss=False) else: signed_keys = [] for key_json in json_results: - key_json = json.loads(key_json.decode("utf-8")) + key_json = json_decoder.decode(key_json.decode("utf-8")) for signing_key in self.config.key_server_signing_keys: key_json = sign_json(key_json, self.config.server_name, signing_key) @@ -223,4 +225,4 @@ class RemoteKey(DirectServeJsonResource): results = {"server_keys": signed_keys} - respond_with_json_bytes(request, 200, encode_canonical_json(results)) + respond_with_json(request, 200, results, canonical_json=True) diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 9a847130c0..6568e61829 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py
@@ -17,7 +17,9 @@ import logging import os import urllib +from typing import Awaitable +from twisted.internet.interfaces import IConsumer from twisted.protocols.basic import FileSender from synapse.api.errors import Codes, SynapseError, cs_error @@ -233,21 +235,21 @@ async def respond_with_responder( finish_request(request) -class Responder(object): +class Responder: """Represents a response that can be streamed to the requester. Responder is a context manager which *must* be used, so that any resources held can be cleaned up. """ - def write_to_consumer(self, consumer): + def write_to_consumer(self, consumer: IConsumer) -> Awaitable: """Stream response into consumer Args: - consumer (IConsumer) + consumer: The consumer to stream into. Returns: - Deferred: Resolves once the response has finished being written + Resolves once the response has finished being written """ pass @@ -258,7 +260,7 @@ class Responder(object): pass -class FileInfo(object): +class FileInfo: """Details about a requested/uploaded file. Attributes: diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py
index e25c382c9c..7447eeaebe 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py
@@ -33,7 +33,7 @@ def _wrap_in_base_path(func): return _wrapped -class MediaFilePaths(object): +class MediaFilePaths: """Describes where files are stored on disk. Most of the functions have a `*_rel` variant which returns a file path that @@ -80,7 +80,7 @@ class MediaFilePaths(object): self, server_name, file_id, width, height, content_type, method ): top_level_type, sub_type = content_type.split("/") - file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) + file_name = "%i-%i-%s-%s-%s" % (width, height, top_level_type, sub_type, method) return os.path.join( "remote_thumbnail", server_name, @@ -92,6 +92,23 @@ class MediaFilePaths(object): remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel) + # Legacy path that was used to store thumbnails previously. + # Should be removed after some time, when most of the thumbnails are stored + # using the new path. + def remote_media_thumbnail_rel_legacy( + self, server_name, file_id, width, height, content_type + ): + top_level_type, sub_type = content_type.split("/") + file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type) + return os.path.join( + "remote_thumbnail", + server_name, + file_id[0:2], + file_id[2:4], + file_id[4:], + file_name, + ) + def remote_media_thumbnail_dir(self, server_name, file_id): return os.path.join( self.base_path, diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 45628c07b4..e1192b47cd 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py
@@ -18,10 +18,11 @@ import errno import logging import os import shutil -from typing import Dict, Tuple +from typing import IO, Dict, Optional, Tuple import twisted.internet.error import twisted.web.http +from twisted.web.http import Request from twisted.web.resource import Resource from synapse.api.errors import ( @@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string from ._base import ( FileInfo, + Responder, get_filename_from_headers, respond_404, respond_with_responder, @@ -51,7 +53,7 @@ from .media_storage import MediaStorage from .preview_url_resource import PreviewUrlResource from .storage_provider import StorageProviderWrapper from .thumbnail_resource import ThumbnailResource -from .thumbnailer import Thumbnailer +from .thumbnailer import Thumbnailer, ThumbnailError from .upload_resource import UploadResource logger = logging.getLogger(__name__) @@ -60,7 +62,7 @@ logger = logging.getLogger(__name__) UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000 -class MediaRepository(object): +class MediaRepository: def __init__(self, hs): self.hs = hs self.auth = hs.get_auth() @@ -135,20 +137,26 @@ class MediaRepository(object): self.recently_accessed_locals.add(media_id) async def create_content( - self, media_type, upload_name, content, content_length, auth_user - ): + self, + media_type: str, + upload_name: Optional[str], + content: IO, + content_length: int, + auth_user: str, + ) -> str: """Store uploaded content for a local user and return the mxc URL Args: - media_type(str): The content type of the file - upload_name(str): The name of the file + media_type: The content type of the file. + upload_name: The name of the file, if provided. content: A file like object that is the content to store - content_length(int): The length of the content - auth_user(str): The user_id of the uploader + content_length: The length of the content + auth_user: The user_id of the uploader Returns: - Deferred[str]: The mxc url of the stored content + The mxc url of the stored content """ + media_id = random_string(24) file_info = FileInfo(server_name=None, file_id=media_id) @@ -170,19 +178,20 @@ class MediaRepository(object): return "mxc://%s/%s" % (self.server_name, media_id) - async def get_local_media(self, request, media_id, name): + async def get_local_media( + self, request: Request, media_id: str, name: Optional[str] + ) -> None: """Responds to reqests for local media, if exists, or returns 404. Args: - request(twisted.web.http.Request) - media_id (str): The media ID of the content. (This is the same as + request: The incoming request. + media_id: The media ID of the content. (This is the same as the file_id for local content.) - name (str|None): Optional name that, if specified, will be used as + name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. Returns: - Deferred: Resolves once a response has successfully been written - to request + Resolves once a response has successfully been written to request """ media_info = await self.store.get_local_media(media_id) if not media_info or media_info["quarantined_by"]: @@ -203,20 +212,20 @@ class MediaRepository(object): request, responder, media_type, media_length, upload_name ) - async def get_remote_media(self, request, server_name, media_id, name): + async def get_remote_media( + self, request: Request, server_name: str, media_id: str, name: Optional[str] + ) -> None: """Respond to requests for remote media. Args: - request(twisted.web.http.Request) - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the - remote server). - name (str|None): Optional name that, if specified, will be used as + request: The incoming request. + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). + name: Optional name that, if specified, will be used as the filename in the Content-Disposition header of the response. Returns: - Deferred: Resolves once a response has successfully been written - to request + Resolves once a response has successfully been written to request """ if ( self.federation_domain_whitelist is not None @@ -245,17 +254,16 @@ class MediaRepository(object): else: respond_404(request) - async def get_remote_media_info(self, server_name, media_id): + async def get_remote_media_info(self, server_name: str, media_id: str) -> dict: """Gets the media info associated with the remote file, downloading if necessary. Args: - server_name (str): Remote server_name where the media originated. - media_id (str): The media ID of the content (as defined by the - remote server). + server_name: Remote server_name where the media originated. + media_id: The media ID of the content (as defined by the remote server). Returns: - Deferred[dict]: The media_info of the file + The media info of the file """ if ( self.federation_domain_whitelist is not None @@ -278,7 +286,9 @@ class MediaRepository(object): return media_info - async def _get_remote_media_impl(self, server_name, media_id): + async def _get_remote_media_impl( + self, server_name: str, media_id: str + ) -> Tuple[Optional[Responder], dict]: """Looks for media in local cache, if not there then attempt to download from remote server. @@ -288,7 +298,7 @@ class MediaRepository(object): remote server). Returns: - Deferred[(Responder, media_info)] + A tuple of responder and the media info of the file. """ media_info = await self.store.get_cached_remote_media(server_name, media_id) @@ -319,19 +329,21 @@ class MediaRepository(object): responder = await self.media_storage.fetch_media(file_info) return responder, media_info - async def _download_remote_file(self, server_name, media_id, file_id): + async def _download_remote_file( + self, server_name: str, media_id: str, file_id: str + ) -> dict: """Attempt to download the remote file from the given server name, using the given file_id as the local id. Args: - server_name (str): Originating server - media_id (str): The media ID of the content (as defined by the + server_name: Originating server + media_id: The media ID of the content (as defined by the remote server). This is different than the file_id, which is locally generated. - file_id (str): Local file ID + file_id: Local file ID Returns: - Deferred[MediaInfo] + The media info of the file. """ file_info = FileInfo(server_name=server_name, file_id=file_id) @@ -449,13 +461,30 @@ class MediaRepository(object): return t_byte_source async def generate_local_exact_thumbnail( - self, media_id, t_width, t_height, t_method, t_type, url_cache - ): + self, + media_id: str, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + url_cache: str, + ) -> Optional[str]: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(None, media_id, url_cache=url_cache) ) - thumbnailer = Thumbnailer(input_path) + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for local media %s using a method of %s and type of %s: %s", + media_id, + t_method, + t_type, + e, + ) + return None + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, @@ -495,14 +524,36 @@ class MediaRepository(object): return output_path + # Could not generate thumbnail. + return None + async def generate_remote_exact_thumbnail( - self, server_name, file_id, media_id, t_width, t_height, t_method, t_type - ): + self, + server_name: str, + file_id: str, + media_id: str, + t_width: int, + t_height: int, + t_method: str, + t_type: str, + ) -> Optional[str]: input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=False) ) - thumbnailer = Thumbnailer(input_path) + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate a thumbnail for remote media %s from %s using a method of %s and type of %s: %s", + media_id, + server_name, + t_method, + t_type, + e, + ) + return None + t_byte_source = await defer_to_thread( self.hs.get_reactor(), self._generate_thumbnail, @@ -548,32 +599,52 @@ class MediaRepository(object): return output_path + # Could not generate thumbnail. + return None + async def _generate_thumbnails( - self, server_name, media_id, file_id, media_type, url_cache=False - ): + self, + server_name: Optional[str], + media_id: str, + file_id: str, + media_type: str, + url_cache: bool = False, + ) -> Optional[dict]: """Generate and store thumbnails for an image. Args: - server_name (str|None): The server name if remote media, else None if local - media_id (str): The media ID of the content. (This is the same as + server_name: The server name if remote media, else None if local + media_id: The media ID of the content. (This is the same as the file_id for local content) - file_id (str): Local file ID - media_type (str): The content type of the file - url_cache (bool): If we are thumbnailing images downloaded for the URL cache, + file_id: Local file ID + media_type: The content type of the file + url_cache: If we are thumbnailing images downloaded for the URL cache, used exclusively by the url previewer Returns: - Deferred[dict]: Dict with "width" and "height" keys of original image + Dict with "width" and "height" keys of original image or None if the + media cannot be thumbnailed. """ requirements = self._get_thumbnail_requirements(media_type) if not requirements: - return + return None input_path = await self.media_storage.ensure_media_is_in_local_cache( FileInfo(server_name, file_id, url_cache=url_cache) ) - thumbnailer = Thumbnailer(input_path) + try: + thumbnailer = Thumbnailer(input_path) + except ThumbnailError as e: + logger.warning( + "Unable to generate thumbnails for remote media %s from %s of type %s: %s", + media_id, + server_name, + media_type, + e, + ) + return None + m_width = thumbnailer.width m_height = thumbnailer.height @@ -584,7 +655,7 @@ class MediaRepository(object): m_height, self.max_image_pixels, ) - return + return None if thumbnailer.transpose_method is not None: m_width, m_height = await defer_to_thread( diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
index 66bc1c3360..a9586fb0b7 100644 --- a/synapse/rest/media/v1/media_storage.py +++ b/synapse/rest/media/v1/media_storage.py
@@ -12,13 +12,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import contextlib -import inspect import logging import os import shutil -from typing import Optional +from typing import IO, TYPE_CHECKING, Any, Optional, Sequence from twisted.protocols.basic import FileSender @@ -26,28 +24,39 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable from synapse.util.file_consumer import BackgroundFileConsumer from ._base import FileInfo, Responder +from .filepath import MediaFilePaths + +if TYPE_CHECKING: + from synapse.server import HomeServer + + from .storage_provider import StorageProviderWrapper logger = logging.getLogger(__name__) -class MediaStorage(object): +class MediaStorage: """Responsible for storing/fetching files from local sources. Args: - hs (synapse.server.Homeserver) - local_media_directory (str): Base path where we store media on disk - filepaths (MediaFilePaths) - storage_providers ([StorageProvider]): List of StorageProvider that are - used to fetch and store files. + hs + local_media_directory: Base path where we store media on disk + filepaths + storage_providers: List of StorageProvider that are used to fetch and store files. """ - def __init__(self, hs, local_media_directory, filepaths, storage_providers): + def __init__( + self, + hs: "HomeServer", + local_media_directory: str, + filepaths: MediaFilePaths, + storage_providers: Sequence["StorageProviderWrapper"], + ): self.hs = hs self.local_media_directory = local_media_directory self.filepaths = filepaths self.storage_providers = storage_providers - async def store_file(self, source, file_info: FileInfo) -> str: + async def store_file(self, source: IO, file_info: FileInfo) -> str: """Write `source` to the on disk media store, and also any other configured storage providers @@ -69,7 +78,7 @@ class MediaStorage(object): return fname @contextlib.contextmanager - def store_into_file(self, file_info): + def store_into_file(self, file_info: FileInfo): """Context manager used to get a file like object to write into, as described by file_info. @@ -85,7 +94,7 @@ class MediaStorage(object): error. Args: - file_info (FileInfo): Info about the file to store + file_info: Info about the file to store Example: @@ -105,11 +114,7 @@ class MediaStorage(object): async def finish(): for provider in self.storage_providers: - # store_file is supposed to return an Awaitable, but guard - # against improper implementations. - result = provider.store_file(path, file_info) - if inspect.isawaitable(result): - await result + await provider.store_file(path, file_info) finished_called[0] = True @@ -136,21 +141,34 @@ class MediaStorage(object): Returns: Returns a Responder if the file was found, otherwise None. """ + paths = [self._file_info_to_path(file_info)] - path = self._file_info_to_path(file_info) - local_path = os.path.join(self.local_media_directory, path) - if os.path.exists(local_path): - return FileResponder(open(local_path, "rb")) + # fallback for remote thumbnails with no method in the filename + if file_info.thumbnail and file_info.server_name: + paths.append( + self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + ) + ) + + for path in paths: + local_path = os.path.join(self.local_media_directory, path) + if os.path.exists(local_path): + logger.debug("responding with local file %s", local_path) + return FileResponder(open(local_path, "rb")) + logger.debug("local file %s did not exist", local_path) for provider in self.storage_providers: - res = provider.fetch(path, file_info) - # Fetch is supposed to return an Awaitable, but guard against - # improper implementations. - if inspect.isawaitable(res): - res = await res - if res: - logger.debug("Streaming %s from %s", path, provider) - return res + for path in paths: + res = await provider.fetch(path, file_info) # type: Any + if res: + logger.debug("Streaming %s from %s", path, provider) + return res + logger.debug("%s not found on %s", path, provider) return None @@ -169,16 +187,26 @@ class MediaStorage(object): if os.path.exists(local_path): return local_path + # Fallback for paths without method names + # Should be removed in the future + if file_info.thumbnail and file_info.server_name: + legacy_path = self.filepaths.remote_media_thumbnail_rel_legacy( + server_name=file_info.server_name, + file_id=file_info.file_id, + width=file_info.thumbnail_width, + height=file_info.thumbnail_height, + content_type=file_info.thumbnail_type, + ) + legacy_local_path = os.path.join(self.local_media_directory, legacy_path) + if os.path.exists(legacy_local_path): + return legacy_local_path + dirname = os.path.dirname(local_path) if not os.path.exists(dirname): os.makedirs(dirname) for provider in self.storage_providers: - res = provider.fetch(path, file_info) - # Fetch is supposed to return an Awaitable, but guard against - # improper implementations. - if inspect.isawaitable(res): - res = await res + res = await provider.fetch(path, file_info) # type: Any if res: with res: consumer = BackgroundFileConsumer( @@ -190,17 +218,11 @@ class MediaStorage(object): raise Exception("file could not be found") - def _file_info_to_path(self, file_info): + def _file_info_to_path(self, file_info: FileInfo) -> str: """Converts file_info into a relative path. The path is suitable for storing files under a directory, e.g. used to store files on local FS under the base media repository directory. - - Args: - file_info (FileInfo) - - Returns: - str """ if file_info.url_cache: if file_info.thumbnail: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 13d1a6d2ed..dce6c4d168 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -27,9 +27,7 @@ from typing import Dict, Optional from urllib import parse as urlparse import attr -from canonicaljson import json -from twisted.internet import defer from twisted.internet.error import DNSLookupError from synapse.api.errors import Codes, SynapseError @@ -43,6 +41,7 @@ from synapse.http.servlet import parse_integer, parse_string from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.media.v1._base import get_filename_from_headers +from synapse.util import json_encoder from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.stringutils import random_string @@ -103,7 +102,7 @@ for endpoint, globs in _oembed_globs.items(): _oembed_patterns[re.compile(pattern)] = endpoint -@attr.s +@attr.s(slots=True) class OEmbedResult: # Either HTML content or URL must be provided. html = attr.ib(type=Optional[str]) @@ -228,19 +227,19 @@ class PreviewUrlResource(DirectServeJsonResource): else: logger.info("Returning cached response") - og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe)) + og = await make_deferred_yieldable(observable.observe()) respond_with_json_bytes(request, 200, og, send_cors=True) - async def _do_preview(self, url, user, ts): + async def _do_preview(self, url: str, user: str, ts: int) -> bytes: """Check the db, and download the URL and build a preview Args: - url (str): - user (str): - ts (int): + url: The URL to preview. + user: The user requesting the preview. + ts: The timestamp requested for the preview. Returns: - Deferred[bytes]: json-encoded og data + json-encoded og data """ # check the URL cache in the DB (which will also provide us with # historical previews, if we have any) @@ -355,7 +354,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("Calculated OG for %s as %s", url, og) - jsonog = json.dumps(og) + jsonog = json_encoder.encode(og) # store OG in history-aware DB cache await self.store.store_url_cache( @@ -451,7 +450,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.warning("Error downloading oEmbed metadata from %s: %r", url, e) raise OEmbedError() from e - async def _download_url(self, url, user): + async def _download_url(self, url: str, user): # TODO: we should probably honour robots.txt... except in practice # we're most likely being explicitly triggered by a human rather than a # bot, so are we really a robot? @@ -461,7 +460,7 @@ class PreviewUrlResource(DirectServeJsonResource): file_info = FileInfo(server_name=None, file_id=file_id, url_cache=True) # If this URL can be accessed via oEmbed, use that instead. - url_to_download = url + url_to_download = url # type: Optional[str] oembed_url = self._get_oembed_url(url) if oembed_url: # The result might be a new URL to download, or it might be HTML content. @@ -521,9 +520,15 @@ class PreviewUrlResource(DirectServeJsonResource): # FIXME: we should calculate a proper expiration based on the # Cache-Control and Expire headers. But for now, assume 1 hour. expires = ONE_HOUR - etag = headers["ETag"][0] if "ETag" in headers else None + etag = ( + headers[b"ETag"][0].decode("ascii") if b"ETag" in headers else None + ) else: - html_bytes = oembed_result.html.encode("utf-8") # type: ignore + # we can only get here if we did an oembed request and have an oembed_result.html + assert oembed_result.html is not None + assert oembed_url is not None + + html_bytes = oembed_result.html.encode("utf-8") with self.media_storage.store_into_file(file_info) as (f, fname, finish): f.write(html_bytes) await finish() @@ -586,7 +591,7 @@ class PreviewUrlResource(DirectServeJsonResource): logger.debug("Running url preview cache expiry") - if not (await self.store.db.updates.has_completed_background_updates()): + if not (await self.store.db_pool.updates.has_completed_background_updates()): logger.info("Still running DB updates; skipping expiry") return diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 858680be26..18c9ed48d6 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py
@@ -13,65 +13,66 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect import logging import os import shutil - -from twisted.internet import defer +from typing import Optional from synapse.config._base import Config from synapse.logging.context import defer_to_thread, run_in_background +from ._base import FileInfo, Responder from .media_storage import FileResponder logger = logging.getLogger(__name__) -class StorageProvider(object): +class StorageProvider: """A storage provider is a service that can store uploaded media and retrieve them. """ - def store_file(self, path, file_info): + async def store_file(self, path: str, file_info: FileInfo): """Store the file described by file_info. The actual contents can be retrieved by reading the file in file_info.upload_path. Args: - path (str): Relative path of file in local cache - file_info (FileInfo) - - Returns: - Deferred + path: Relative path of file in local cache + file_info: The metadata of the file. """ - pass - def fetch(self, path, file_info): + async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: """Attempt to fetch the file described by file_info and stream it into writer. Args: - path (str): Relative path of file in local cache - file_info (FileInfo) + path: Relative path of file in local cache + file_info: The metadata of the file. Returns: - Deferred(Responder): Returns a Responder if the provider has the file, - otherwise returns None. + Returns a Responder if the provider has the file, otherwise returns None. """ - pass class StorageProviderWrapper(StorageProvider): """Wraps a storage provider and provides various config options Args: - backend (StorageProvider) - store_local (bool): Whether to store new local files or not. - store_synchronous (bool): Whether to wait for file to be successfully + backend: The storage provider to wrap. + store_local: Whether to store new local files or not. + store_synchronous: Whether to wait for file to be successfully uploaded, or todo the upload in the background. - store_remote (bool): Whether remote media should be uploaded + store_remote: Whether remote media should be uploaded """ - def __init__(self, backend, store_local, store_synchronous, store_remote): + def __init__( + self, + backend: StorageProvider, + store_local: bool, + store_synchronous: bool, + store_remote: bool, + ): self.backend = backend self.store_local = store_local self.store_synchronous = store_synchronous @@ -80,28 +81,38 @@ class StorageProviderWrapper(StorageProvider): def __str__(self): return "StorageProviderWrapper[%s]" % (self.backend,) - def store_file(self, path, file_info): + async def store_file(self, path, file_info): if not file_info.server_name and not self.store_local: - return defer.succeed(None) + return None if file_info.server_name and not self.store_remote: - return defer.succeed(None) + return None if self.store_synchronous: - return self.backend.store_file(path, file_info) + # store_file is supposed to return an Awaitable, but guard + # against improper implementations. + result = self.backend.store_file(path, file_info) + if inspect.isawaitable(result): + return await result else: # TODO: Handle errors. - def store(): + async def store(): try: - return self.backend.store_file(path, file_info) + result = self.backend.store_file(path, file_info) + if inspect.isawaitable(result): + return await result except Exception: logger.exception("Error storing file") run_in_background(store) - return defer.succeed(None) + return None - def fetch(self, path, file_info): - return self.backend.fetch(path, file_info) + async def fetch(self, path, file_info): + # store_file is supposed to return an Awaitable, but guard + # against improper implementations. + result = self.backend.fetch(path, file_info) + if inspect.isawaitable(result): + return await result class FileStorageProviderBackend(StorageProvider): @@ -120,7 +131,7 @@ class FileStorageProviderBackend(StorageProvider): def __str__(self): return "FileStorageProviderBackend[%s]" % (self.base_directory,) - def store_file(self, path, file_info): + async def store_file(self, path, file_info): """See StorageProvider.store_file""" primary_fname = os.path.join(self.cache_directory, path) @@ -130,11 +141,11 @@ class FileStorageProviderBackend(StorageProvider): if not os.path.exists(dirname): os.makedirs(dirname) - return defer_to_thread( + return await defer_to_thread( self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname ) - def fetch(self, path, file_info): + async def fetch(self, path, file_info): """See StorageProvider.fetch""" backup_fname = os.path.join(self.base_directory, path) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index a83535b97b..30421b663a 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -16,6 +16,7 @@ import logging +from synapse.api.errors import SynapseError from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_integer, parse_string @@ -173,7 +174,7 @@ class ThumbnailResource(DirectServeJsonResource): await respond_with_file(request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") - respond_404(request) + raise SynapseError(400, "Failed to generate thumbnail.") async def _select_or_generate_remote_thumbnail( self, @@ -235,7 +236,7 @@ class ThumbnailResource(DirectServeJsonResource): await respond_with_file(request, desired_type, file_path) else: logger.warning("Failed to generate thumbnail") - respond_404(request) + raise SynapseError(400, "Failed to generate thumbnail.") async def _respond_remote_thumbnail( self, request, server_name, media_id, width, height, method, m_type diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 7126997134..32a8e4f960 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py
@@ -15,7 +15,7 @@ import logging from io import BytesIO -from PIL import Image as Image +from PIL import Image logger = logging.getLogger(__name__) @@ -31,12 +31,22 @@ EXIF_TRANSPOSE_MAPPINGS = { } -class Thumbnailer(object): +class ThumbnailError(Exception): + """An error occurred generating a thumbnail.""" + + +class Thumbnailer: FORMATS = {"image/jpeg": "JPEG", "image/png": "PNG"} def __init__(self, input_path): - self.image = Image.open(input_path) + try: + self.image = Image.open(input_path) + except OSError as e: + # If an error occurs opening the image, a thumbnail won't be able to + # be generated. + raise ThumbnailError from e + self.width, self.height = self.image.size self.transpose_method = None try: @@ -73,7 +83,7 @@ class Thumbnailer(object): Args: max_width: The largest possible width. - max_height: The larget possible height. + max_height: The largest possible height. """ if max_width * self.height < max_height * self.width: @@ -107,7 +117,7 @@ class Thumbnailer(object): Args: max_width: The largest possible width. - max_height: The larget possible height. + max_height: The largest possible height. Returns: BytesIO: the bytes of the encoded image ready to be written to disk diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index 3ebf7a68e6..d76f7389e1 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py
@@ -63,6 +63,10 @@ class UploadResource(DirectServeJsonResource): msg="Invalid UTF-8 filename parameter: %r" % (upload_name), code=400 ) + # If the name is falsey (e.g. an empty byte string) ensure it is None. + else: + upload_name = None + headers = request.requestHeaders if headers.hasHeader(b"Content-Type"): diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index c10188a5d7..f6668fb5e3 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py
@@ -13,10 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.python import failure -from synapse.api.errors import SynapseError -from synapse.http.server import DirectServeHtmlResource, return_html_error +from synapse.http.server import DirectServeHtmlResource class SAML2ResponseResource(DirectServeHtmlResource): @@ -27,21 +25,15 @@ class SAML2ResponseResource(DirectServeHtmlResource): def __init__(self, hs): super().__init__() self._saml_handler = hs.get_saml_handler() - self._error_html_template = hs.config.saml2.saml2_error_html_template async def _async_render_GET(self, request): # We're not expecting any GET request on that resource if everything goes right, # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. # In this case, just tell the user that something went wrong and they should # try to authenticate again. - f = failure.Failure( - SynapseError(400, "Unexpected GET request on /saml2/authn_response") + self._saml_handler._render_error( + request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" ) - return_html_error(f, request, self._error_html_template) async def _async_render_POST(self, request): - try: - await self._saml_handler.handle_saml_response(request) - except Exception: - f = failure.Failure() - return_html_error(f, request, self._error_html_template) + await self._saml_handler.handle_saml_response(request) diff --git a/synapse/rest/synapse/__init__.py b/synapse/rest/synapse/__init__.py new file mode 100644
index 0000000000..c0b733488b --- /dev/null +++ b/synapse/rest/synapse/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py new file mode 100644
index 0000000000..c0b733488b --- /dev/null +++ b/synapse/rest/synapse/client/__init__.py
@@ -0,0 +1,14 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py new file mode 100644
index 0000000000..9e4fbc0cbd --- /dev/null +++ b/synapse/rest/synapse/client/password_reset.py
@@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING, Tuple + +from twisted.web.http import Request + +from synapse.api.errors import ThreepidValidationError +from synapse.config.emailconfig import ThreepidBehaviour +from synapse.http.server import DirectServeHtmlResource +from synapse.http.servlet import parse_string +from synapse.util.stringutils import assert_valid_client_secret + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class PasswordResetSubmitTokenResource(DirectServeHtmlResource): + """Handles 3PID validation token submission + + This resource gets mounted under /_synapse/client/password_reset/email/submit_token + """ + + isLeaf = 1 + + def __init__(self, hs: "HomeServer"): + """ + Args: + hs: server + """ + super().__init__() + + self.clock = hs.get_clock() + self.store = hs.get_datastore() + + self._local_threepid_handling_disabled_due_to_email_config = ( + hs.config.local_threepid_handling_disabled_due_to_email_config + ) + self._confirmation_email_template = ( + hs.config.email_password_reset_template_confirmation_html + ) + self._email_password_reset_template_success_html = ( + hs.config.email_password_reset_template_success_html_content + ) + self._failure_email_template = ( + hs.config.email_password_reset_template_failure_html + ) + + # This resource should not be mounted if threepid behaviour is not LOCAL + assert hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL + + async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]: + sid = parse_string(request, "sid", required=True) + token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + assert_valid_client_secret(client_secret) + + # Show a confirmation page, just in case someone accidentally clicked this link when + # they didn't mean to + template_vars = { + "sid": sid, + "token": token, + "client_secret": client_secret, + } + return ( + 200, + self._confirmation_email_template.render(**template_vars).encode("utf-8"), + ) + + async def _async_render_POST(self, request: Request) -> Tuple[int, bytes]: + sid = parse_string(request, "sid", required=True) + token = parse_string(request, "token", required=True) + client_secret = parse_string(request, "client_secret", required=True) + + # Attempt to validate a 3PID session + try: + # Mark the session as valid + next_link = await self.store.validate_threepid_session( + sid, client_secret, token, self.clock.time_msec() + ) + + # Perform a 302 redirect if next_link is set + if next_link: + if next_link.startswith("file:///"): + logger.warning( + "Not redirecting to next_link as it is a local file: address" + ) + else: + next_link_bytes = next_link.encode("utf-8") + request.setHeader("Location", next_link_bytes) + return ( + 302, + ( + b'You are being redirected to <a src="%s">%s</a>.' + % (next_link_bytes, next_link_bytes) + ), + ) + + # Otherwise show the success template + html_bytes = self._email_password_reset_template_success_html.encode( + "utf-8" + ) + status_code = 200 + except ThreepidValidationError as e: + status_code = e.code + + # Show a failure page with a reason + template_vars = {"failure_reason": e.msg} + html_bytes = self._failure_email_template.render(**template_vars).encode( + "utf-8" + ) + + return status_code, html_bytes diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py
index 20177b44e7..f591cc6c5c 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py
@@ -13,17 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging from twisted.web.resource import Resource from synapse.http.server import set_cors_headers +from synapse.util import json_encoder logger = logging.getLogger(__name__) -class WellKnownBuilder(object): +class WellKnownBuilder: """Utility to construct the well-known response Args: @@ -67,4 +67,4 @@ class WellKnownResource(Resource): logger.debug("returning: %s", r) request.setHeader(b"Content-Type", b"application/json") - return json.dumps(r).encode("utf-8") + return json_encoder.encode(r).encode("utf-8")