diff options
Diffstat (limited to 'synapse/rest')
44 files changed, 953 insertions, 345 deletions
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 3a24d31d1b..e6110ad9b1 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.client.v2_alpha import ( read_marker, receipts, register, + relations, report_event, room_keys, room_upgrade_rest_servlet, @@ -115,6 +116,7 @@ class ClientRestResource(JsonResource): room_upgrade_rest_servlet.register_servlets(hs, client_resource) capabilities.register_servlets(hs, client_resource) account_validity.register_servlets(hs, client_resource) + relations.register_servlets(hs, client_resource) # moving to /_synapse/admin synapse.rest.admin.register_servlets_for_client_rest_resource( diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 744d85594f..d6c4dcdb18 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -822,10 +822,16 @@ class AdminRestResource(JsonResource): def __init__(self, hs): JsonResource.__init__(self, hs, canonical_json=False) + register_servlets(hs, self) - register_servlets_for_client_rest_resource(hs, self) - SendServerNoticeServlet(hs).register(self) - VersionServlet(hs).register(self) + +def register_servlets(hs, http_server): + """ + Register all the admin servlets. + """ + register_servlets_for_client_rest_resource(hs, http_server) + SendServerNoticeServlet(hs).register(http_server) + VersionServlet(hs).register(http_server) def register_servlets_for_client_rest_resource(hs, http_server): diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py deleted file mode 100644 index dc63b661c0..0000000000 --- a/synapse/rest/client/v1/base.py +++ /dev/null @@ -1,65 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""This module contains base REST classes for constructing client v1 servlets. -""" - -import logging -import re - -from synapse.api.urls import CLIENT_API_PREFIX -from synapse.http.servlet import RestServlet -from synapse.rest.client.transactions import HttpTransactionCache - -logger = logging.getLogger(__name__) - - -def client_path_patterns(path_regex, releases=(0,), include_in_unstable=True): - """Creates a regex compiled client path with the correct client path - prefix. - - Args: - path_regex (str): The regex string to match. This should NOT have a ^ - as this will be prefixed. - Returns: - SRE_Pattern - """ - patterns = [re.compile("^" + CLIENT_API_PREFIX + "/api/v1" + path_regex)] - if include_in_unstable: - unstable_prefix = CLIENT_API_PREFIX + "/unstable" - patterns.append(re.compile("^" + unstable_prefix + path_regex)) - for release in releases: - new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) - patterns.append(re.compile("^" + new_prefix + path_regex)) - return patterns - - -class ClientV1RestServlet(RestServlet): - """A base Synapse REST Servlet for the client version 1 API. - """ - - # This subclass was presumably created to allow the auth for the v1 - # protocol version to be different, however this behaviour was removed. - # it may no longer be necessary - - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): - """ - self.hs = hs - self.builder_factory = hs.get_event_builder_factory() - self.auth = hs.get_auth() - self.txns = HttpTransactionCache(hs) diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 0220acf644..0035182bb9 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -19,11 +19,10 @@ import logging from twisted.internet import defer from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.types import RoomAlias -from .base import ClientV1RestServlet, client_path_patterns - logger = logging.getLogger(__name__) @@ -33,13 +32,14 @@ def register_servlets(hs, http_server): ClientAppserviceDirectoryListServer(hs).register(http_server) -class ClientDirectoryServer(ClientV1RestServlet): - PATTERNS = client_path_patterns("/directory/room/(?P<room_alias>[^/]*)$") +class ClientDirectoryServer(RestServlet): + PATTERNS = client_patterns("/directory/room/(?P<room_alias>[^/]*)$", v1=True) def __init__(self, hs): - super(ClientDirectoryServer, self).__init__(hs) + super(ClientDirectoryServer, self).__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_alias): @@ -120,13 +120,14 @@ class ClientDirectoryServer(ClientV1RestServlet): defer.returnValue((200, {})) -class ClientDirectoryListServer(ClientV1RestServlet): - PATTERNS = client_path_patterns("/directory/list/room/(?P<room_id>[^/]*)$") +class ClientDirectoryListServer(RestServlet): + PATTERNS = client_patterns("/directory/list/room/(?P<room_id>[^/]*)$", v1=True) def __init__(self, hs): - super(ClientDirectoryListServer, self).__init__(hs) + super(ClientDirectoryListServer, self).__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -162,15 +163,16 @@ class ClientDirectoryListServer(ClientV1RestServlet): defer.returnValue((200, {})) -class ClientAppserviceDirectoryListServer(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$" +class ClientAppserviceDirectoryListServer(RestServlet): + PATTERNS = client_patterns( + "/directory/list/appservice/(?P<network_id>[^/]*)/(?P<room_id>[^/]*)$", v1=True ) def __init__(self, hs): - super(ClientAppserviceDirectoryListServer, self).__init__(hs) + super(ClientAppserviceDirectoryListServer, self).__init__() self.store = hs.get_datastore() self.handlers = hs.get_handlers() + self.auth = hs.get_auth() def on_PUT(self, request, network_id, room_id): content = parse_json_object_from_request(request) diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index c3b0a39ab7..84ca36270b 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -19,21 +19,22 @@ import logging from twisted.internet import defer from synapse.api.errors import SynapseError +from synapse.http.servlet import RestServlet +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.streams.config import PaginationConfig -from .base import ClientV1RestServlet, client_path_patterns - logger = logging.getLogger(__name__) -class EventStreamRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/events$") +class EventStreamRestServlet(RestServlet): + PATTERNS = client_patterns("/events$", v1=True) DEFAULT_LONGPOLL_TIME_MS = 30000 def __init__(self, hs): - super(EventStreamRestServlet, self).__init__(hs) + super(EventStreamRestServlet, self).__init__() self.event_stream_handler = hs.get_event_stream_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request): @@ -76,11 +77,11 @@ class EventStreamRestServlet(ClientV1RestServlet): # TODO: Unit test gets, with and without auth, with different kinds of events. -class EventRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/events/(?P<event_id>[^/]*)$") +class EventRestServlet(RestServlet): + PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True) def __init__(self, hs): - super(EventRestServlet, self).__init__(hs) + super(EventRestServlet, self).__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 3ead75cb77..0fe5f2d79b 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -15,19 +15,19 @@ from twisted.internet import defer -from synapse.http.servlet import parse_boolean +from synapse.http.servlet import RestServlet, parse_boolean +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.streams.config import PaginationConfig -from .base import ClientV1RestServlet, client_path_patterns - # TODO: Needs unit testing -class InitialSyncRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/initialSync$") +class InitialSyncRestServlet(RestServlet): + PATTERNS = client_patterns("/initialSync$", v1=True) def __init__(self, hs): - super(InitialSyncRestServlet, self).__init__(hs) + super(InitialSyncRestServlet, self).__init__() self.initial_sync_handler = hs.get_initial_sync_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 5180e9eaf1..3b60728628 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -29,12 +29,11 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.well_known import WellKnownBuilder from synapse.types import UserID, map_username_to_mxid_localpart from synapse.util.msisdn import phone_number_to_msisdn -from .base import ClientV1RestServlet, client_path_patterns - logger = logging.getLogger(__name__) @@ -81,15 +80,16 @@ def login_id_thirdparty_from_phone(identifier): } -class LoginRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/login$") +class LoginRestServlet(RestServlet): + PATTERNS = client_patterns("/login$", v1=True) CAS_TYPE = "m.login.cas" SSO_TYPE = "m.login.sso" TOKEN_TYPE = "m.login.token" JWT_TYPE = "m.login.jwt" def __init__(self, hs): - super(LoginRestServlet, self).__init__(hs) + super(LoginRestServlet, self).__init__() + self.hs = hs self.jwt_enabled = hs.config.jwt_enabled self.jwt_secret = hs.config.jwt_secret self.jwt_algorithm = hs.config.jwt_algorithm @@ -371,7 +371,7 @@ class LoginRestServlet(ClientV1RestServlet): class CasRedirectServlet(RestServlet): - PATTERNS = client_path_patterns("/login/(cas|sso)/redirect") + PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) def __init__(self, hs): super(CasRedirectServlet, self).__init__() @@ -386,7 +386,7 @@ class CasRedirectServlet(RestServlet): b"redirectUrl": args[b"redirectUrl"][0] }).encode('ascii') hs_redirect_url = (self.cas_service_url + - b"/_matrix/client/api/v1/login/cas/ticket") + b"/_matrix/client/r0/login/cas/ticket") service_param = urllib.parse.urlencode({ b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param) }).encode('ascii') @@ -394,27 +394,27 @@ class CasRedirectServlet(RestServlet): finish_request(request) -class CasTicketServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/login/cas/ticket", releases=()) +class CasTicketServlet(RestServlet): + PATTERNS = client_patterns("/login/cas/ticket", v1=True) def __init__(self, hs): - super(CasTicketServlet, self).__init__(hs) + super(CasTicketServlet, self).__init__() self.cas_server_url = hs.config.cas_server_url self.cas_service_url = hs.config.cas_service_url self.cas_required_attributes = hs.config.cas_required_attributes self._sso_auth_handler = SSOAuthHandler(hs) + self._http_client = hs.get_simple_http_client() @defer.inlineCallbacks def on_GET(self, request): client_redirect_url = parse_string(request, "redirectUrl", required=True) - http_client = self.hs.get_simple_http_client() uri = self.cas_server_url + "/proxyValidate" args = { "ticket": parse_string(request, "ticket", required=True), "service": self.cas_service_url } try: - body = yield http_client.get_raw(uri, args) + body = yield self._http_client.get_raw(uri, args) except PartialDownloadError as pde: # Twisted raises this error if the connection is closed, # even if that's being used old-http style to signal end-of-data diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py index 430c692336..b8064f261e 100644 --- a/synapse/rest/client/v1/logout.py +++ b/synapse/rest/client/v1/logout.py @@ -17,19 +17,18 @@ import logging from twisted.internet import defer -from synapse.api.errors import AuthError - -from .base import ClientV1RestServlet, client_path_patterns +from synapse.http.servlet import RestServlet +from synapse.rest.client.v2_alpha._base import client_patterns logger = logging.getLogger(__name__) -class LogoutRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/logout$") +class LogoutRestServlet(RestServlet): + PATTERNS = client_patterns("/logout$", v1=True) def __init__(self, hs): - super(LogoutRestServlet, self).__init__(hs) - self._auth = hs.get_auth() + super(LogoutRestServlet, self).__init__() + self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() @@ -38,32 +37,25 @@ class LogoutRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - try: - requester = yield self.auth.get_user_by_req(request) - except AuthError: - # this implies the access token has already been deleted. - defer.returnValue((401, { - "errcode": "M_UNKNOWN_TOKEN", - "error": "Access Token unknown or expired" - })) + requester = yield self.auth.get_user_by_req(request) + + if requester.device_id is None: + # the acccess token wasn't associated with a device. + # Just delete the access token + access_token = self.auth.get_access_token_from_request(request) + yield self._auth_handler.delete_access_token(access_token) else: - if requester.device_id is None: - # the acccess token wasn't associated with a device. - # Just delete the access token - access_token = self._auth.get_access_token_from_request(request) - yield self._auth_handler.delete_access_token(access_token) - else: - yield self._device_handler.delete_device( - requester.user.to_string(), requester.device_id) + yield self._device_handler.delete_device( + requester.user.to_string(), requester.device_id) defer.returnValue((200, {})) -class LogoutAllRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/logout/all$") +class LogoutAllRestServlet(RestServlet): + PATTERNS = client_patterns("/logout/all$", v1=True) def __init__(self, hs): - super(LogoutAllRestServlet, self).__init__(hs) + super(LogoutAllRestServlet, self).__init__() self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index 045d5a20ac..e263da3cb7 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -23,21 +23,22 @@ from twisted.internet import defer from synapse.api.errors import AuthError, SynapseError from synapse.handlers.presence import format_user_presence_state -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.types import UserID -from .base import ClientV1RestServlet, client_path_patterns - logger = logging.getLogger(__name__) -class PresenceStatusRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status") +class PresenceStatusRestServlet(RestServlet): + PATTERNS = client_patterns("/presence/(?P<user_id>[^/]*)/status", v1=True) def __init__(self, hs): - super(PresenceStatusRestServlet, self).__init__(hs) + super(PresenceStatusRestServlet, self).__init__() + self.hs = hs self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, user_id): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index eac1966c5e..e15d9d82a6 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -16,18 +16,19 @@ """ This module contains REST servlets to do with profile: /profile/<paths> """ from twisted.internet import defer -from synapse.http.servlet import parse_json_object_from_request +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.types import UserID -from .base import ClientV1RestServlet, client_path_patterns - -class ProfileDisplaynameRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/displayname") +class ProfileDisplaynameRestServlet(RestServlet): + PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/displayname", v1=True) def __init__(self, hs): - super(ProfileDisplaynameRestServlet, self).__init__(hs) + super(ProfileDisplaynameRestServlet, self).__init__() + self.hs = hs self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -71,12 +72,14 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): return (200, {}) -class ProfileAvatarURLRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)/avatar_url") +class ProfileAvatarURLRestServlet(RestServlet): + PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)/avatar_url", v1=True) def __init__(self, hs): - super(ProfileAvatarURLRestServlet, self).__init__(hs) + super(ProfileAvatarURLRestServlet, self).__init__() + self.hs = hs self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, user_id): @@ -119,12 +122,14 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): return (200, {}) -class ProfileRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/profile/(?P<user_id>[^/]*)") +class ProfileRestServlet(RestServlet): + PATTERNS = client_patterns("/profile/(?P<user_id>[^/]*)", v1=True) def __init__(self, hs): - super(ProfileRestServlet, self).__init__(hs) + super(ProfileRestServlet, self).__init__() + self.hs = hs self.profile_handler = hs.get_profile_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, user_id): diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py index 506ec95ddd..3d6326fe2f 100644 --- a/synapse/rest/client/v1/push_rule.py +++ b/synapse/rest/client/v1/push_rule.py @@ -21,22 +21,22 @@ from synapse.api.errors import ( SynapseError, UnrecognizedRequestError, ) -from synapse.http.servlet import parse_json_value_from_request, parse_string +from synapse.http.servlet import RestServlet, parse_json_value_from_request, parse_string from synapse.push.baserules import BASE_RULE_IDS from synapse.push.clientformat import format_push_rules_for_user from synapse.push.rulekinds import PRIORITY_CLASS_MAP +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException -from .base import ClientV1RestServlet, client_path_patterns - -class PushRuleRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/(?P<path>pushrules/.*)$") +class PushRuleRestServlet(RestServlet): + PATTERNS = client_patterns("/(?P<path>pushrules/.*)$", v1=True) SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( "Unrecognised request: You probably wanted a trailing slash") def __init__(self, hs): - super(PushRuleRestServlet, self).__init__(hs) + super(PushRuleRestServlet, self).__init__() + self.auth = hs.get_auth() self.store = hs.get_datastore() self.notifier = hs.get_notifier() self._is_worker = hs.config.worker_app is not None diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py index 4c07ae7f45..15d860db37 100644 --- a/synapse/rest/client/v1/pusher.py +++ b/synapse/rest/client/v1/pusher.py @@ -26,17 +26,18 @@ from synapse.http.servlet import ( parse_string, ) from synapse.push import PusherConfigException - -from .base import ClientV1RestServlet, client_path_patterns +from synapse.rest.client.v2_alpha._base import client_patterns logger = logging.getLogger(__name__) -class PushersRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/pushers$") +class PushersRestServlet(RestServlet): + PATTERNS = client_patterns("/pushers$", v1=True) def __init__(self, hs): - super(PushersRestServlet, self).__init__(hs) + super(PushersRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request): @@ -69,11 +70,13 @@ class PushersRestServlet(ClientV1RestServlet): return 200, {} -class PushersSetRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/pushers/set$") +class PushersSetRestServlet(RestServlet): + PATTERNS = client_patterns("/pushers/set$", v1=True) def __init__(self, hs): - super(PushersSetRestServlet, self).__init__(hs) + super(PushersSetRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() @@ -141,7 +144,7 @@ class PushersRemoveRestServlet(RestServlet): """ To allow pusher to be delete by clicking a link (ie. GET request) """ - PATTERNS = client_path_patterns("/pushers/remove$") + PATTERNS = client_patterns("/pushers/remove$", v1=True) SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>" def __init__(self, hs): diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 255a85c588..e8f672c4ba 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -28,37 +28,45 @@ from synapse.api.errors import AuthError, Codes, SynapseError from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 from synapse.http.servlet import ( + RestServlet, assert_params_in_dict, parse_integer, parse_json_object_from_request, parse_string, ) +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID -from .base import ClientV1RestServlet, client_path_patterns - logger = logging.getLogger(__name__) -class RoomCreateRestServlet(ClientV1RestServlet): +class TransactionRestServlet(RestServlet): + def __init__(self, hs): + super(TransactionRestServlet, self).__init__() + self.txns = HttpTransactionCache(hs) + + +class RoomCreateRestServlet(TransactionRestServlet): # No PATTERN; we have custom dispatch rules here def __init__(self, hs): super(RoomCreateRestServlet, self).__init__(hs) self._room_creation_handler = hs.get_room_creation_handler() + self.auth = hs.get_auth() def register(self, http_server): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) # define CORS for all of /rooms in RoomCreateRestServlet for simplicity http_server.register_paths("OPTIONS", - client_path_patterns("/rooms(?:/.*)?$"), + client_patterns("/rooms(?:/.*)?$", v1=True), self.on_OPTIONS) # define CORS for /createRoom[/txnid] http_server.register_paths("OPTIONS", - client_path_patterns("/createRoom(?:/.*)?$"), + client_patterns("/createRoom(?:/.*)?$", v1=True), self.on_OPTIONS) def on_PUT(self, request, txn_id): @@ -85,13 +93,14 @@ class RoomCreateRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events -class RoomStateEventRestServlet(ClientV1RestServlet): +class RoomStateEventRestServlet(TransactionRestServlet): def __init__(self, hs): super(RoomStateEventRestServlet, self).__init__(hs) self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() def register(self, http_server): # /room/$roomid/state/$eventtype @@ -102,16 +111,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet): "(?P<event_type>[^/]*)/(?P<state_key>[^/]*)$") http_server.register_paths("GET", - client_path_patterns(state_key), + client_patterns(state_key, v1=True), self.on_GET) http_server.register_paths("PUT", - client_path_patterns(state_key), + client_patterns(state_key, v1=True), self.on_PUT) http_server.register_paths("GET", - client_path_patterns(no_state_key), + client_patterns(no_state_key, v1=True), self.on_GET_no_state_key) http_server.register_paths("PUT", - client_path_patterns(no_state_key), + client_patterns(no_state_key, v1=True), self.on_PUT_no_state_key) def on_GET_no_state_key(self, request, room_id, event_type): @@ -185,11 +194,12 @@ class RoomStateEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for generic events + feedback -class RoomSendEventRestServlet(ClientV1RestServlet): +class RoomSendEventRestServlet(TransactionRestServlet): def __init__(self, hs): super(RoomSendEventRestServlet, self).__init__(hs) self.event_creation_handler = hs.get_event_creation_handler() + self.auth = hs.get_auth() def register(self, http_server): # /rooms/$roomid/send/$event_type[/$txn_id] @@ -229,10 +239,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet): # TODO: Needs unit testing for room ID + alias joins -class JoinRoomAliasServlet(ClientV1RestServlet): +class JoinRoomAliasServlet(TransactionRestServlet): def __init__(self, hs): super(JoinRoomAliasServlet, self).__init__(hs) self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() def register(self, http_server): # /join/$room_identifier[/$txn_id] @@ -291,8 +302,13 @@ class JoinRoomAliasServlet(ClientV1RestServlet): # TODO: Needs unit testing -class PublicRoomListRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/publicRooms$") +class PublicRoomListRestServlet(TransactionRestServlet): + PATTERNS = client_patterns("/publicRooms$", v1=True) + + def __init__(self, hs): + super(PublicRoomListRestServlet, self).__init__(hs) + self.hs = hs + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request): @@ -382,12 +398,13 @@ class PublicRoomListRestServlet(ClientV1RestServlet): # TODO: Needs unit testing -class RoomMemberListRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/members$") +class RoomMemberListRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/members$", v1=True) def __init__(self, hs): - super(RoomMemberListRestServlet, self).__init__(hs) + super(RoomMemberListRestServlet, self).__init__() self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -436,12 +453,13 @@ class RoomMemberListRestServlet(ClientV1RestServlet): # deprecated in favour of /members?membership=join? # except it does custom AS logic and has a simpler return format -class JoinedRoomMemberListRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$") +class JoinedRoomMemberListRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/joined_members$", v1=True) def __init__(self, hs): - super(JoinedRoomMemberListRestServlet, self).__init__(hs) + super(JoinedRoomMemberListRestServlet, self).__init__() self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -457,12 +475,13 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet): # TODO: Needs better unit testing -class RoomMessageListRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/messages$") +class RoomMessageListRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/messages$", v1=True) def __init__(self, hs): - super(RoomMessageListRestServlet, self).__init__(hs) + super(RoomMessageListRestServlet, self).__init__() self.pagination_handler = hs.get_pagination_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -475,6 +494,8 @@ class RoomMessageListRestServlet(ClientV1RestServlet): if filter_bytes: filter_json = urlparse.unquote(filter_bytes.decode("UTF-8")) event_filter = Filter(json.loads(filter_json)) + if event_filter.filter_json.get("event_format", "client") == "federation": + as_client_event = False else: event_filter = None msgs = yield self.pagination_handler.get_messages( @@ -489,12 +510,13 @@ class RoomMessageListRestServlet(ClientV1RestServlet): # TODO: Needs unit testing -class RoomStateRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/state$") +class RoomStateRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/state$", v1=True) def __init__(self, hs): - super(RoomStateRestServlet, self).__init__(hs) + super(RoomStateRestServlet, self).__init__() self.message_handler = hs.get_message_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -509,12 +531,13 @@ class RoomStateRestServlet(ClientV1RestServlet): # TODO: Needs unit testing -class RoomInitialSyncRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$") +class RoomInitialSyncRestServlet(RestServlet): + PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/initialSync$", v1=True) def __init__(self, hs): - super(RoomInitialSyncRestServlet, self).__init__(hs) + super(RoomInitialSyncRestServlet, self).__init__() self.initial_sync_handler = hs.get_initial_sync_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id): @@ -528,16 +551,17 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): defer.returnValue((200, content)) -class RoomEventServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$" +class RoomEventServlet(RestServlet): + PATTERNS = client_patterns( + "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$", v1=True ) def __init__(self, hs): - super(RoomEventServlet, self).__init__(hs) + super(RoomEventServlet, self).__init__() self.clock = hs.get_clock() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): @@ -552,16 +576,17 @@ class RoomEventServlet(ClientV1RestServlet): defer.returnValue((404, "Event not found.")) -class RoomEventContextServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$" +class RoomEventContextServlet(RestServlet): + PATTERNS = client_patterns( + "/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$", v1=True ) def __init__(self, hs): - super(RoomEventContextServlet, self).__init__(hs) + super(RoomEventContextServlet, self).__init__() self.clock = hs.get_clock() self.room_context_handler = hs.get_room_context_handler() self._event_serializer = hs.get_event_client_serializer() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request, room_id, event_id): @@ -607,10 +632,11 @@ class RoomEventContextServlet(ClientV1RestServlet): defer.returnValue((200, results)) -class RoomForgetRestServlet(ClientV1RestServlet): +class RoomForgetRestServlet(TransactionRestServlet): def __init__(self, hs): super(RoomForgetRestServlet, self).__init__(hs) self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() def register(self, http_server): PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget") @@ -637,11 +663,12 @@ class RoomForgetRestServlet(ClientV1RestServlet): # TODO: Needs unit testing -class RoomMembershipRestServlet(ClientV1RestServlet): +class RoomMembershipRestServlet(TransactionRestServlet): def __init__(self, hs): super(RoomMembershipRestServlet, self).__init__(hs) self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() def register(self, http_server): # /rooms/$roomid/[invite|join|leave] @@ -720,11 +747,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet): ) -class RoomRedactEventRestServlet(ClientV1RestServlet): +class RoomRedactEventRestServlet(TransactionRestServlet): def __init__(self, hs): super(RoomRedactEventRestServlet, self).__init__(hs) self.handlers = hs.get_handlers() self.event_creation_handler = hs.get_event_creation_handler() + self.auth = hs.get_auth() def register(self, http_server): PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)") @@ -755,15 +783,16 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): ) -class RoomTypingRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$" +class RoomTypingRestServlet(RestServlet): + PATTERNS = client_patterns( + "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$", v1=True ) def __init__(self, hs): - super(RoomTypingRestServlet, self).__init__(hs) + super(RoomTypingRestServlet, self).__init__() self.presence_handler = hs.get_presence_handler() self.typing_handler = hs.get_typing_handler() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): @@ -796,14 +825,13 @@ class RoomTypingRestServlet(ClientV1RestServlet): defer.returnValue((200, {})) -class SearchRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns( - "/search$" - ) +class SearchRestServlet(RestServlet): + PATTERNS = client_patterns("/search$", v1=True) def __init__(self, hs): - super(SearchRestServlet, self).__init__(hs) + super(SearchRestServlet, self).__init__() self.handlers = hs.get_handlers() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_POST(self, request): @@ -821,12 +849,13 @@ class SearchRestServlet(ClientV1RestServlet): defer.returnValue((200, results)) -class JoinedRoomsRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/joined_rooms$") +class JoinedRoomsRestServlet(RestServlet): + PATTERNS = client_patterns("/joined_rooms$", v1=True) def __init__(self, hs): - super(JoinedRoomsRestServlet, self).__init__(hs) + super(JoinedRoomsRestServlet, self).__init__() self.store = hs.get_datastore() + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request): @@ -851,18 +880,18 @@ def register_txn_path(servlet, regex_string, http_server, with_get=False): """ http_server.register_paths( "POST", - client_path_patterns(regex_string + "$"), + client_patterns(regex_string + "$", v1=True), servlet.on_POST ) http_server.register_paths( "PUT", - client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"), + client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True), servlet.on_PUT ) if with_get: http_server.register_paths( "GET", - client_path_patterns(regex_string + "/(?P<txn_id>[^/]*)$"), + client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True), servlet.on_GET ) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 53da905eea..6381049210 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -19,11 +19,17 @@ import hmac from twisted.internet import defer -from .base import ClientV1RestServlet, client_path_patterns +from synapse.http.servlet import RestServlet +from synapse.rest.client.v2_alpha._base import client_patterns -class VoipRestServlet(ClientV1RestServlet): - PATTERNS = client_path_patterns("/voip/turnServer$") +class VoipRestServlet(RestServlet): + PATTERNS = client_patterns("/voip/turnServer$", v1=True) + + def __init__(self, hs): + super(VoipRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() @defer.inlineCallbacks def on_GET(self, request): diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index 24ac26bf03..5236d5d566 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -26,8 +26,7 @@ from synapse.api.urls import CLIENT_API_PREFIX logger = logging.getLogger(__name__) -def client_v2_patterns(path_regex, releases=(0,), - unstable=True): +def client_patterns(path_regex, releases=(0,), unstable=True, v1=False): """Creates a regex compiled client path with the correct client path prefix. @@ -41,6 +40,9 @@ def client_v2_patterns(path_regex, releases=(0,), if unstable: unstable_prefix = CLIENT_API_PREFIX + "/unstable" patterns.append(re.compile("^" + unstable_prefix + path_regex)) + if v1: + v1_prefix = CLIENT_API_PREFIX + "/api/v1" + patterns.append(re.compile("^" + v1_prefix + path_regex)) for release in releases: new_prefix = CLIENT_API_PREFIX + "/r%d" % (release,) patterns.append(re.compile("^" + new_prefix + path_regex)) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index ee069179f0..e4c63b69b9 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -15,43 +15,74 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import re from six.moves import http_client +import jinja2 + from twisted.internet import defer from synapse.api.constants import LoginType -from synapse.api.errors import Codes, SynapseError +from synapse.api.errors import Codes, SynapseError, ThreepidValidationError +from synapse.http.server import finish_request from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, + parse_string, ) from synapse.util.msisdn import phone_number_to_msisdn +from synapse.util.stringutils import random_string from synapse.util.threepids import check_3pid_allowed -from ._base import client_v2_patterns, interactive_auth_handler +from ._base import client_patterns, interactive_auth_handler logger = logging.getLogger(__name__) class EmailPasswordRequestTokenRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/password/email/requestToken$") + PATTERNS = client_patterns("/account/password/email/requestToken$") def __init__(self, hs): super(EmailPasswordRequestTokenRestServlet, self).__init__() self.hs = hs + self.datastore = hs.get_datastore() + self.config = hs.config self.identity_handler = hs.get_handlers().identity_handler + if self.config.email_password_reset_behaviour == "local": + from synapse.push.mailer import Mailer, load_jinja2_templates + templates = load_jinja2_templates( + config=hs.config, + template_html_name=hs.config.email_password_reset_template_html, + template_text_name=hs.config.email_password_reset_template_text, + ) + self.mailer = Mailer( + hs=self.hs, + app_name=self.config.email_app_name, + template_html=templates[0], + template_text=templates[1], + ) + @defer.inlineCallbacks def on_POST(self, request): + if self.config.email_password_reset_behaviour == "off": + raise SynapseError(400, "Password resets have been disabled on this server") + body = parse_json_object_from_request(request) assert_params_in_dict(body, [ - 'id_server', 'client_secret', 'email', 'send_attempt' + 'client_secret', 'email', 'send_attempt' ]) - if not check_3pid_allowed(self.hs, "email", body['email']): + # Extract params from body + client_secret = body["client_secret"] + email = body["email"] + send_attempt = body["send_attempt"] + next_link = body.get("next_link") # Optional param + + if not check_3pid_allowed(self.hs, "email", email): raise SynapseError( 403, "Your email domain is not authorized on this server", @@ -59,18 +90,103 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', body['email'] + 'email', email, ) if existingUid is None: raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) - ret = yield self.identity_handler.requestEmailToken(**body) + if self.config.email_password_reset_behaviour == "remote": + if 'id_server' not in body: + raise SynapseError(400, "Missing 'id_server' param in body") + + # Have the identity server handle the password reset flow + ret = yield self.identity_handler.requestEmailToken( + body["id_server"], email, client_secret, send_attempt, next_link, + ) + else: + # Send password reset emails from Synapse + sid = yield self.send_password_reset( + email, client_secret, send_attempt, next_link, + ) + + # Wrap the session id in a JSON object + ret = {"sid": sid} + defer.returnValue((200, ret)) + @defer.inlineCallbacks + def send_password_reset( + self, + email, + client_secret, + send_attempt, + next_link=None, + ): + """Send a password reset email + + Args: + email (str): The user's email address + client_secret (str): The provided client secret + send_attempt (int): Which send attempt this is + + Returns: + The new session_id upon success + + Raises: + SynapseError is an error occurred when sending the email + """ + # Check that this email/client_secret/send_attempt combo is new or + # greater than what we've seen previously + session = yield self.datastore.get_threepid_validation_session( + "email", client_secret, address=email, validated=False, + ) + + # Check to see if a session already exists and that it is not yet + # marked as validated + if session and session.get("validated_at") is None: + session_id = session['session_id'] + last_send_attempt = session['last_send_attempt'] + + # Check that the send_attempt is higher than previous attempts + if send_attempt <= last_send_attempt: + # If not, just return a success without sending an email + defer.returnValue(session_id) + else: + # An non-validated session does not exist yet. + # Generate a session id + session_id = random_string(16) + + # Generate a new validation token + token = random_string(32) + + # Send the mail with the link containing the token, client_secret + # and session_id + try: + yield self.mailer.send_password_reset_mail( + email, token, client_secret, session_id, + ) + except Exception: + logger.exception( + "Error sending a password reset email to %s", email, + ) + raise SynapseError( + 500, "An error was encountered when sending the password reset email" + ) + + token_expires = (self.hs.clock.time_msec() + + self.config.email_validation_token_lifetime) + + yield self.datastore.start_or_continue_validation_session( + "email", email, session_id, client_secret, send_attempt, + next_link, token, token_expires, + ) + + defer.returnValue(session_id) + class MsisdnPasswordRequestTokenRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/password/msisdn/requestToken$") + PATTERNS = client_patterns("/account/password/msisdn/requestToken$") def __init__(self, hs): super(MsisdnPasswordRequestTokenRestServlet, self).__init__() @@ -80,6 +196,9 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): @defer.inlineCallbacks def on_POST(self, request): + if not self.config.email_password_reset_behaviour == "off": + raise SynapseError(400, "Password resets have been disabled on this server") + body = parse_json_object_from_request(request) assert_params_in_dict(body, [ @@ -107,8 +226,120 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet): defer.returnValue((200, ret)) +class PasswordResetSubmitTokenServlet(RestServlet): + """Handles 3PID validation token submission""" + PATTERNS = [ + re.compile("^/_synapse/password_reset/(?P<medium>[^/]*)/submit_token/*$"), + ] + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(PasswordResetSubmitTokenServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.config = hs.config + self.clock = hs.get_clock() + self.datastore = hs.get_datastore() + + @defer.inlineCallbacks + def on_GET(self, request, medium): + if medium != "email": + raise SynapseError( + 400, + "This medium is currently not supported for password resets", + ) + + sid = parse_string(request, "sid") + client_secret = parse_string(request, "client_secret") + token = parse_string(request, "token") + + # Attempt to validate a 3PID sesssion + try: + # Mark the session as valid + next_link = yield self.datastore.validate_threepid_session( + sid, + client_secret, + token, + self.clock.time_msec(), + ) + + # Perform a 302 redirect if next_link is set + if next_link: + if next_link.startswith("file:///"): + logger.warn( + "Not redirecting to next_link as it is a local file: address" + ) + else: + request.setResponseCode(302) + request.setHeader("Location", next_link) + finish_request(request) + defer.returnValue(None) + + # Otherwise show the success template + html = self.config.email_password_reset_success_html_content + request.setResponseCode(200) + except ThreepidValidationError as e: + # Show a failure page with a reason + html = self.load_jinja2_template( + self.config.email_template_dir, + self.config.email_password_reset_failure_template, + template_vars={ + "failure_reason": e.msg, + } + ) + request.setResponseCode(e.code) + + request.write(html.encode('utf-8')) + finish_request(request) + defer.returnValue(None) + + def load_jinja2_template(self, template_dir, template_filename, template_vars): + """Loads a jinja2 template with variables to insert + + Args: + template_dir (str): The directory where templates are stored + template_filename (str): The name of the template in the template_dir + template_vars (Dict): Dictionary of keys in the template + alongside their values to insert + + Returns: + str containing the contents of the rendered template + """ + loader = jinja2.FileSystemLoader(template_dir) + env = jinja2.Environment(loader=loader) + + template = env.get_template(template_filename) + return template.render(**template_vars) + + @defer.inlineCallbacks + def on_POST(self, request, medium): + if medium != "email": + raise SynapseError( + 400, + "This medium is currently not supported for password resets", + ) + + body = parse_json_object_from_request(request) + assert_params_in_dict(body, [ + 'sid', 'client_secret', 'token', + ]) + + valid, _ = yield self.datastore.validate_threepid_validation_token( + body['sid'], + body['client_secret'], + body['token'], + self.clock.time_msec(), + ) + response_code = 200 if valid else 400 + + defer.returnValue((response_code, {"success": valid})) + + class PasswordRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/password$") + PATTERNS = client_patterns("/account/password$") def __init__(self, hs): super(PasswordRestServlet, self).__init__() @@ -144,6 +375,7 @@ class PasswordRestServlet(RestServlet): result, params, _ = yield self.auth_handler.check_auth( [[LoginType.EMAIL_IDENTITY], [LoginType.MSISDN]], body, self.hs.get_ip_from_request(request), + password_servlet=True, ) if LoginType.EMAIL_IDENTITY in result: @@ -180,7 +412,7 @@ class PasswordRestServlet(RestServlet): class DeactivateAccountRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/deactivate$") + PATTERNS = client_patterns("/account/deactivate$") def __init__(self, hs): super(DeactivateAccountRestServlet, self).__init__() @@ -228,7 +460,7 @@ class DeactivateAccountRestServlet(RestServlet): class EmailThreepidRequestTokenRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") + PATTERNS = client_patterns("/account/3pid/email/requestToken$") def __init__(self, hs): self.hs = hs @@ -263,7 +495,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): class MsisdnThreepidRequestTokenRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid/msisdn/requestToken$") + PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") def __init__(self, hs): self.hs = hs @@ -300,7 +532,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): class ThreepidRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid$") + PATTERNS = client_patterns("/account/3pid$") def __init__(self, hs): super(ThreepidRestServlet, self).__init__() @@ -364,7 +596,7 @@ class ThreepidRestServlet(RestServlet): class ThreepidDeleteRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid/delete$") + PATTERNS = client_patterns("/account/3pid/delete$") def __init__(self, hs): super(ThreepidDeleteRestServlet, self).__init__() @@ -401,7 +633,7 @@ class ThreepidDeleteRestServlet(RestServlet): class WhoamiRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/whoami$") + PATTERNS = client_patterns("/account/whoami$") def __init__(self, hs): super(WhoamiRestServlet, self).__init__() @@ -417,6 +649,7 @@ class WhoamiRestServlet(RestServlet): def register_servlets(hs, http_server): EmailPasswordRequestTokenRestServlet(hs).register(http_server) MsisdnPasswordRequestTokenRestServlet(hs).register(http_server) + PasswordResetSubmitTokenServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) DeactivateAccountRestServlet(hs).register(http_server) EmailThreepidRequestTokenRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index f171b8d626..574a6298ce 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError, NotFoundError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -30,7 +30,7 @@ class AccountDataServlet(RestServlet): PUT /user/{user_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/account_data/{account_dataType} HTTP/1.1 """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/user/(?P<user_id>[^/]*)/account_data/(?P<account_data_type>[^/]*)" ) @@ -79,7 +79,7 @@ class RoomAccountDataServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 GET /user/{user_id}/rooms/{room_id}/account_data/{account_dataType} HTTP/1.1 """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/user/(?P<user_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)" "/account_data/(?P<account_data_type>[^/]*)" diff --git a/synapse/rest/client/v2_alpha/account_validity.py b/synapse/rest/client/v2_alpha/account_validity.py index 9bc1e208ca..63bdc33564 100644 --- a/synapse/rest/client/v2_alpha/account_validity.py +++ b/synapse/rest/client/v2_alpha/account_validity.py @@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, SynapseError from synapse.http.server import finish_request from synapse.http.servlet import RestServlet -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class AccountValidityRenewServlet(RestServlet): - PATTERNS = client_v2_patterns("/account_validity/renew$") + PATTERNS = client_patterns("/account_validity/renew$") SUCCESS_HTML = b"<html><body>Your account has been successfully renewed.</body><html>" def __init__(self, hs): @@ -60,7 +60,7 @@ class AccountValidityRenewServlet(RestServlet): class AccountValiditySendMailServlet(RestServlet): - PATTERNS = client_v2_patterns("/account_validity/send_mail$") + PATTERNS = client_patterns("/account_validity/send_mail$") def __init__(self, hs): """ diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 4c380ab84d..8dfe5cba02 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -23,7 +23,7 @@ from synapse.api.urls import CLIENT_API_PREFIX from synapse.http.server import finish_request from synapse.http.servlet import RestServlet, parse_string -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -122,7 +122,7 @@ class AuthRestServlet(RestServlet): cannot be handled in the normal flow (with requests to the same endpoint). Current use is for web fallback auth. """ - PATTERNS = client_v2_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") + PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web") def __init__(self, hs): super(AuthRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/capabilities.py b/synapse/rest/client/v2_alpha/capabilities.py index a868d06098..fc7e2f4dd5 100644 --- a/synapse/rest/client/v2_alpha/capabilities.py +++ b/synapse/rest/client/v2_alpha/capabilities.py @@ -16,10 +16,10 @@ import logging from twisted.internet import defer -from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.servlet import RestServlet -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ logger = logging.getLogger(__name__) class CapabilitiesRestServlet(RestServlet): """End point to expose the capabilities of the server.""" - PATTERNS = client_v2_patterns("/capabilities$") + PATTERNS = client_patterns("/capabilities$") def __init__(self, hs): """ @@ -36,6 +36,7 @@ class CapabilitiesRestServlet(RestServlet): """ super(CapabilitiesRestServlet, self).__init__() self.hs = hs + self.config = hs.config self.auth = hs.get_auth() self.store = hs.get_datastore() @@ -48,7 +49,7 @@ class CapabilitiesRestServlet(RestServlet): response = { "capabilities": { "m.room_versions": { - "default": DEFAULT_ROOM_VERSION.identifier, + "default": self.config.default_room_version.identifier, "available": { v.identifier: v.disposition for v in KNOWN_ROOM_VERSIONS.values() diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py index 5a5be7c390..78665304a5 100644 --- a/synapse/rest/client/v2_alpha/devices.py +++ b/synapse/rest/client/v2_alpha/devices.py @@ -24,13 +24,13 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) -from ._base import client_v2_patterns, interactive_auth_handler +from ._base import client_patterns, interactive_auth_handler logger = logging.getLogger(__name__) class DevicesRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/devices$") + PATTERNS = client_patterns("/devices$") def __init__(self, hs): """ @@ -56,7 +56,7 @@ class DeleteDevicesRestServlet(RestServlet): API for bulk deletion of devices. Accepts a JSON object with a devices key which lists the device_ids to delete. Requires user interactive auth. """ - PATTERNS = client_v2_patterns("/delete_devices") + PATTERNS = client_patterns("/delete_devices") def __init__(self, hs): super(DeleteDevicesRestServlet, self).__init__() @@ -95,7 +95,7 @@ class DeleteDevicesRestServlet(RestServlet): class DeviceRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$") + PATTERNS = client_patterns("/devices/(?P<device_id>[^/]*)$") def __init__(self, hs): """ diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py index ae86728879..65db48c3cc 100644 --- a/synapse/rest/client/v2_alpha/filter.py +++ b/synapse/rest/client/v2_alpha/filter.py @@ -21,13 +21,13 @@ from synapse.api.errors import AuthError, Codes, StoreError, SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID -from ._base import client_v2_patterns, set_timeline_upper_limit +from ._base import client_patterns, set_timeline_upper_limit logger = logging.getLogger(__name__) class GetFilterRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") + PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter/(?P<filter_id>[^/]*)") def __init__(self, hs): super(GetFilterRestServlet, self).__init__() @@ -63,7 +63,7 @@ class GetFilterRestServlet(RestServlet): class CreateFilterRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/user/(?P<user_id>[^/]*)/filter") + PATTERNS = client_patterns("/user/(?P<user_id>[^/]*)/filter") def __init__(self, hs): super(CreateFilterRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/groups.py b/synapse/rest/client/v2_alpha/groups.py index 21e02c07c0..d082385ec7 100644 --- a/synapse/rest/client/v2_alpha/groups.py +++ b/synapse/rest/client/v2_alpha/groups.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import GroupID -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ logger = logging.getLogger(__name__) class GroupServlet(RestServlet): """Get the group profile """ - PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$") + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/profile$") def __init__(self, hs): super(GroupServlet, self).__init__() @@ -65,7 +65,7 @@ class GroupServlet(RestServlet): class GroupSummaryServlet(RestServlet): """Get the full group summary """ - PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$") + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/summary$") def __init__(self, hs): super(GroupSummaryServlet, self).__init__() @@ -93,7 +93,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): - /groups/:group/summary/rooms/:room_id - /groups/:group/summary/categories/:category/rooms/:room_id """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/summary" "(/categories/(?P<category_id>[^/]+))?" "/rooms/(?P<room_id>[^/]*)$" @@ -137,7 +137,7 @@ class GroupSummaryRoomsCatServlet(RestServlet): class GroupCategoryServlet(RestServlet): """Get/add/update/delete a group category """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$" ) @@ -189,7 +189,7 @@ class GroupCategoryServlet(RestServlet): class GroupCategoriesServlet(RestServlet): """Get all group categories """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/categories/$" ) @@ -214,7 +214,7 @@ class GroupCategoriesServlet(RestServlet): class GroupRoleServlet(RestServlet): """Get/add/update/delete a group role """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$" ) @@ -266,7 +266,7 @@ class GroupRoleServlet(RestServlet): class GroupRolesServlet(RestServlet): """Get all group roles """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/roles/$" ) @@ -295,7 +295,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): - /groups/:group/summary/users/:room_id - /groups/:group/summary/roles/:role/users/:user_id """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/summary" "(/roles/(?P<role_id>[^/]+))?" "/users/(?P<user_id>[^/]*)$" @@ -339,7 +339,7 @@ class GroupSummaryUsersRoleServlet(RestServlet): class GroupRoomServlet(RestServlet): """Get all rooms in a group """ - PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$") + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/rooms$") def __init__(self, hs): super(GroupRoomServlet, self).__init__() @@ -360,7 +360,7 @@ class GroupRoomServlet(RestServlet): class GroupUsersServlet(RestServlet): """Get all users in a group """ - PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$") + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/users$") def __init__(self, hs): super(GroupUsersServlet, self).__init__() @@ -381,7 +381,7 @@ class GroupUsersServlet(RestServlet): class GroupInvitedUsersServlet(RestServlet): """Get users invited to a group """ - PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/invited_users$") def __init__(self, hs): super(GroupInvitedUsersServlet, self).__init__() @@ -405,7 +405,7 @@ class GroupInvitedUsersServlet(RestServlet): class GroupSettingJoinPolicyServlet(RestServlet): """Set group join policy """ - PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") + PATTERNS = client_patterns("/groups/(?P<group_id>[^/]*)/settings/m.join_policy$") def __init__(self, hs): super(GroupSettingJoinPolicyServlet, self).__init__() @@ -431,7 +431,7 @@ class GroupSettingJoinPolicyServlet(RestServlet): class GroupCreateServlet(RestServlet): """Create a group """ - PATTERNS = client_v2_patterns("/create_group$") + PATTERNS = client_patterns("/create_group$") def __init__(self, hs): super(GroupCreateServlet, self).__init__() @@ -462,7 +462,7 @@ class GroupCreateServlet(RestServlet): class GroupAdminRoomsServlet(RestServlet): """Add a room to the group """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$" ) @@ -499,7 +499,7 @@ class GroupAdminRoomsServlet(RestServlet): class GroupAdminRoomsConfigServlet(RestServlet): """Update the config of a room in a group """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)" "/config/(?P<config_key>[^/]*)$" ) @@ -526,7 +526,7 @@ class GroupAdminRoomsConfigServlet(RestServlet): class GroupAdminUsersInviteServlet(RestServlet): """Invite a user to the group """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$" ) @@ -555,7 +555,7 @@ class GroupAdminUsersInviteServlet(RestServlet): class GroupAdminUsersKickServlet(RestServlet): """Kick a user from the group """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$" ) @@ -581,7 +581,7 @@ class GroupAdminUsersKickServlet(RestServlet): class GroupSelfLeaveServlet(RestServlet): """Leave a joined group """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/self/leave$" ) @@ -607,7 +607,7 @@ class GroupSelfLeaveServlet(RestServlet): class GroupSelfJoinServlet(RestServlet): """Attempt to join a group, or knock """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/self/join$" ) @@ -633,7 +633,7 @@ class GroupSelfJoinServlet(RestServlet): class GroupSelfAcceptInviteServlet(RestServlet): """Accept a group invite """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/self/accept_invite$" ) @@ -659,7 +659,7 @@ class GroupSelfAcceptInviteServlet(RestServlet): class GroupSelfUpdatePublicityServlet(RestServlet): """Update whether we publicise a users membership of a group """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/groups/(?P<group_id>[^/]*)/self/update_publicity$" ) @@ -686,7 +686,7 @@ class GroupSelfUpdatePublicityServlet(RestServlet): class PublicisedGroupsForUserServlet(RestServlet): """Get the list of groups a user is advertising """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/publicised_groups/(?P<user_id>[^/]*)$" ) @@ -711,7 +711,7 @@ class PublicisedGroupsForUserServlet(RestServlet): class PublicisedGroupsForUsersServlet(RestServlet): """Get the list of groups a user is advertising """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/publicised_groups$" ) @@ -739,7 +739,7 @@ class PublicisedGroupsForUsersServlet(RestServlet): class GroupsForUserServlet(RestServlet): """Get all groups the logged in user is joined to """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/joined_groups$" ) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 8486086b51..4cbfbf5631 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -26,7 +26,7 @@ from synapse.http.servlet import ( ) from synapse.types import StreamToken -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class KeyUploadServlet(RestServlet): }, } """ - PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") + PATTERNS = client_patterns("/keys/upload(/(?P<device_id>[^/]+))?$") def __init__(self, hs): """ @@ -130,7 +130,7 @@ class KeyQueryServlet(RestServlet): } } } } } } """ - PATTERNS = client_v2_patterns("/keys/query$") + PATTERNS = client_patterns("/keys/query$") def __init__(self, hs): """ @@ -159,7 +159,7 @@ class KeyChangesServlet(RestServlet): 200 OK { "changed": ["@foo:example.com"] } """ - PATTERNS = client_v2_patterns("/keys/changes$") + PATTERNS = client_patterns("/keys/changes$") def __init__(self, hs): """ @@ -209,7 +209,7 @@ class OneTimeKeyServlet(RestServlet): } } } } """ - PATTERNS = client_v2_patterns("/keys/claim$") + PATTERNS = client_patterns("/keys/claim$") def __init__(self, hs): super(OneTimeKeyServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/notifications.py b/synapse/rest/client/v2_alpha/notifications.py index 0a1eb0ae45..53e666989b 100644 --- a/synapse/rest/client/v2_alpha/notifications.py +++ b/synapse/rest/client/v2_alpha/notifications.py @@ -20,13 +20,13 @@ from twisted.internet import defer from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.servlet import RestServlet, parse_integer, parse_string -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class NotificationsServlet(RestServlet): - PATTERNS = client_v2_patterns("/notifications$") + PATTERNS = client_patterns("/notifications$") def __init__(self, hs): super(NotificationsServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/openid.py b/synapse/rest/client/v2_alpha/openid.py index 01c90aa2a3..bb927d9f9d 100644 --- a/synapse/rest/client/v2_alpha/openid.py +++ b/synapse/rest/client/v2_alpha/openid.py @@ -22,7 +22,7 @@ from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.util.stringutils import random_string -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class IdTokenServlet(RestServlet): "expires_in": 3600, } """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/user/(?P<user_id>[^/]*)/openid/request_token" ) diff --git a/synapse/rest/client/v2_alpha/read_marker.py b/synapse/rest/client/v2_alpha/read_marker.py index a6e582a5ae..f4bd0d077f 100644 --- a/synapse/rest/client/v2_alpha/read_marker.py +++ b/synapse/rest/client/v2_alpha/read_marker.py @@ -19,13 +19,13 @@ from twisted.internet import defer from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class ReadMarkerRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$") + PATTERNS = client_patterns("/rooms/(?P<room_id>[^/]*)/read_markers$") def __init__(self, hs): super(ReadMarkerRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py index de370cac45..fa12ac3e4d 100644 --- a/synapse/rest/client/v2_alpha/receipts.py +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -20,13 +20,13 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class ReceiptRestServlet(RestServlet): - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/rooms/(?P<room_id>[^/]*)" "/receipt/(?P<receipt_type>[^/]*)" "/(?P<event_id>[^/]*)$" diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index fa0cedb8d4..79c085408b 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -43,7 +43,7 @@ from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.threepids import check_3pid_allowed -from ._base import client_v2_patterns, interactive_auth_handler +from ._base import client_patterns, interactive_auth_handler # We ought to be using hmac.compare_digest() but on older pythons it doesn't # exist. It's a _really minor_ security flaw to use plain string comparison @@ -60,7 +60,7 @@ logger = logging.getLogger(__name__) class EmailRegisterRequestTokenRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/register/email/requestToken$") + PATTERNS = client_patterns("/register/email/requestToken$") def __init__(self, hs): """ @@ -98,7 +98,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): class MsisdnRegisterRequestTokenRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/register/msisdn/requestToken$") + PATTERNS = client_patterns("/register/msisdn/requestToken$") def __init__(self, hs): """ @@ -142,7 +142,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): class UsernameAvailabilityRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/register/available") + PATTERNS = client_patterns("/register/available") def __init__(self, hs): """ @@ -182,7 +182,7 @@ class UsernameAvailabilityRestServlet(RestServlet): class RegisterRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/register$") + PATTERNS = client_patterns("/register$") def __init__(self, hs): """ @@ -348,18 +348,22 @@ class RegisterRestServlet(RestServlet): if self.hs.config.enable_registration_captcha: # only support 3PIDless registration if no 3PIDs are required if not require_email and not require_msisdn: - flows.extend([[LoginType.RECAPTCHA]]) + # Also add a dummy flow here, otherwise if a client completes + # recaptcha first we'll assume they were going for this flow + # and complete the request, when they could have been trying to + # complete one of the flows with email/msisdn auth. + flows.extend([[LoginType.RECAPTCHA, LoginType.DUMMY]]) # only support the email-only flow if we don't require MSISDN 3PIDs if not require_msisdn: - flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]]) + flows.extend([[LoginType.RECAPTCHA, LoginType.EMAIL_IDENTITY]]) if show_msisdn: # only support the MSISDN-only flow if we don't require email 3PIDs if not require_email: - flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]]) + flows.extend([[LoginType.RECAPTCHA, LoginType.MSISDN]]) # always let users provide both MSISDN & email flows.extend([ - [LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA], + [LoginType.RECAPTCHA, LoginType.MSISDN, LoginType.EMAIL_IDENTITY], ]) else: # only support 3PIDless registration if no 3PIDs are required @@ -382,7 +386,15 @@ class RegisterRestServlet(RestServlet): if self.hs.config.user_consent_at_registration: new_flows = [] for flow in flows: - flow.append(LoginType.TERMS) + inserted = False + # m.login.terms should go near the end but before msisdn or email auth + for i, stage in enumerate(flow): + if stage == LoginType.EMAIL_IDENTITY or stage == LoginType.MSISDN: + flow.insert(i, LoginType.TERMS) + inserted = True + break + if not inserted: + flow.append(LoginType.TERMS) flows.extend(new_flows) auth_result, params, session_id = yield self.auth_handler.check_auth( diff --git a/synapse/rest/client/v2_alpha/relations.py b/synapse/rest/client/v2_alpha/relations.py new file mode 100644 index 0000000000..f8f8742bdc --- /dev/null +++ b/synapse/rest/client/v2_alpha/relations.py @@ -0,0 +1,338 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This class implements the proposed relation APIs from MSC 1849. + +Since the MSC has not been approved all APIs here are unstable and may change at +any time to reflect changes in the MSC. +""" + +import logging + +from twisted.internet import defer + +from synapse.api.constants import EventTypes, RelationTypes +from synapse.api.errors import SynapseError +from synapse.http.servlet import ( + RestServlet, + parse_integer, + parse_json_object_from_request, + parse_string, +) +from synapse.rest.client.transactions import HttpTransactionCache +from synapse.storage.relations import AggregationPaginationToken, RelationPaginationToken + +from ._base import client_patterns + +logger = logging.getLogger(__name__) + + +class RelationSendServlet(RestServlet): + """Helper API for sending events that have relation data. + + Example API shape to send a 👍 reaction to a room: + + POST /rooms/!foo/send_relation/$bar/m.annotation/m.reaction?key=%F0%9F%91%8D + {} + + { + "event_id": "$foobar" + } + """ + + PATTERN = ( + "/rooms/(?P<room_id>[^/]*)/send_relation" + "/(?P<parent_id>[^/]*)/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)" + ) + + def __init__(self, hs): + super(RelationSendServlet, self).__init__() + self.auth = hs.get_auth() + self.event_creation_handler = hs.get_event_creation_handler() + self.txns = HttpTransactionCache(hs) + + def register(self, http_server): + http_server.register_paths( + "POST", + client_patterns(self.PATTERN + "$", releases=()), + self.on_PUT_or_POST, + ) + http_server.register_paths( + "PUT", + client_patterns(self.PATTERN + "/(?P<txn_id>[^/]*)$", releases=()), + self.on_PUT, + ) + + def on_PUT(self, request, *args, **kwargs): + return self.txns.fetch_or_execute_request( + request, self.on_PUT_or_POST, request, *args, **kwargs + ) + + @defer.inlineCallbacks + def on_PUT_or_POST( + self, request, room_id, parent_id, relation_type, event_type, txn_id=None + ): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + if event_type == EventTypes.Member: + # Add relations to a membership is meaningless, so we just deny it + # at the CS API rather than trying to handle it correctly. + raise SynapseError(400, "Cannot send member events with relations") + + content = parse_json_object_from_request(request) + + aggregation_key = parse_string(request, "key", encoding="utf-8") + + content["m.relates_to"] = { + "event_id": parent_id, + "key": aggregation_key, + "rel_type": relation_type, + } + + event_dict = { + "type": event_type, + "content": content, + "room_id": room_id, + "sender": requester.user.to_string(), + } + + event = yield self.event_creation_handler.create_and_send_nonmember_event( + requester, event_dict=event_dict, txn_id=txn_id + ) + + defer.returnValue((200, {"event_id": event.event_id})) + + +class RelationPaginationServlet(RestServlet): + """API to paginate relations on an event by topological ordering, optionally + filtered by relation type and event type. + """ + + PATTERNS = client_patterns( + "/rooms/(?P<room_id>[^/]*)/relations/(?P<parent_id>[^/]*)" + "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", + releases=(), + ) + + def __init__(self, hs): + super(RelationPaginationServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + self.event_handler = hs.get_event_handler() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + yield self.auth.check_in_room_or_world_readable( + room_id, requester.user.to_string() + ) + + # This checks that a) the event exists and b) the user is allowed to + # view it. + yield self.event_handler.get_event(requester.user, room_id, parent_id) + + limit = parse_integer(request, "limit", default=5) + from_token = parse_string(request, "from") + to_token = parse_string(request, "to") + + if from_token: + from_token = RelationPaginationToken.from_string(from_token) + + if to_token: + to_token = RelationPaginationToken.from_string(to_token) + + result = yield self.store.get_relations_for_event( + event_id=parent_id, + relation_type=relation_type, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + events = yield self.store.get_events_as_list( + [c["event_id"] for c in result.chunk] + ) + + now = self.clock.time_msec() + events = yield self._event_serializer.serialize_events(events, now) + + return_value = result.to_dict() + return_value["chunk"] = events + + defer.returnValue((200, return_value)) + + +class RelationAggregationPaginationServlet(RestServlet): + """API to paginate aggregation groups of relations, e.g. paginate the + types and counts of the reactions on the events. + + Example request and response: + + GET /rooms/{room_id}/aggregations/{parent_id} + + { + chunk: [ + { + "type": "m.reaction", + "key": "👍", + "count": 3 + } + ] + } + """ + + PATTERNS = client_patterns( + "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" + "(/(?P<relation_type>[^/]*)(/(?P<event_type>[^/]*))?)?$", + releases=(), + ) + + def __init__(self, hs): + super(RelationAggregationPaginationServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.event_handler = hs.get_event_handler() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, parent_id, relation_type=None, event_type=None): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + yield self.auth.check_in_room_or_world_readable( + room_id, requester.user.to_string() + ) + + # This checks that a) the event exists and b) the user is allowed to + # view it. + yield self.event_handler.get_event(requester.user, room_id, parent_id) + + if relation_type not in (RelationTypes.ANNOTATION, None): + raise SynapseError(400, "Relation type must be 'annotation'") + + limit = parse_integer(request, "limit", default=5) + from_token = parse_string(request, "from") + to_token = parse_string(request, "to") + + if from_token: + from_token = AggregationPaginationToken.from_string(from_token) + + if to_token: + to_token = AggregationPaginationToken.from_string(to_token) + + res = yield self.store.get_aggregation_groups_for_event( + event_id=parent_id, + event_type=event_type, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + defer.returnValue((200, res.to_dict())) + + +class RelationAggregationGroupPaginationServlet(RestServlet): + """API to paginate within an aggregation group of relations, e.g. paginate + all the 👍 reactions on an event. + + Example request and response: + + GET /rooms/{room_id}/aggregations/{parent_id}/m.annotation/m.reaction/👍 + + { + chunk: [ + { + "type": "m.reaction", + "content": { + "m.relates_to": { + "rel_type": "m.annotation", + "key": "👍" + } + } + }, + ... + ] + } + """ + + PATTERNS = client_patterns( + "/rooms/(?P<room_id>[^/]*)/aggregations/(?P<parent_id>[^/]*)" + "/(?P<relation_type>[^/]*)/(?P<event_type>[^/]*)/(?P<key>[^/]*)$", + releases=(), + ) + + def __init__(self, hs): + super(RelationAggregationGroupPaginationServlet, self).__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + self.event_handler = hs.get_event_handler() + + @defer.inlineCallbacks + def on_GET(self, request, room_id, parent_id, relation_type, event_type, key): + requester = yield self.auth.get_user_by_req(request, allow_guest=True) + + yield self.auth.check_in_room_or_world_readable( + room_id, requester.user.to_string() + ) + + # This checks that a) the event exists and b) the user is allowed to + # view it. + yield self.event_handler.get_event(requester.user, room_id, parent_id) + + if relation_type != RelationTypes.ANNOTATION: + raise SynapseError(400, "Relation type must be 'annotation'") + + limit = parse_integer(request, "limit", default=5) + from_token = parse_string(request, "from") + to_token = parse_string(request, "to") + + if from_token: + from_token = RelationPaginationToken.from_string(from_token) + + if to_token: + to_token = RelationPaginationToken.from_string(to_token) + + result = yield self.store.get_relations_for_event( + event_id=parent_id, + relation_type=relation_type, + event_type=event_type, + aggregation_key=key, + limit=limit, + from_token=from_token, + to_token=to_token, + ) + + events = yield self.store.get_events_as_list( + [c["event_id"] for c in result.chunk] + ) + + now = self.clock.time_msec() + events = yield self._event_serializer.serialize_events(events, now) + + return_value = result.to_dict() + return_value["chunk"] = events + + defer.returnValue((200, return_value)) + + +def register_servlets(hs, http_server): + RelationSendServlet(hs).register(http_server) + RelationPaginationServlet(hs).register(http_server) + RelationAggregationPaginationServlet(hs).register(http_server) + RelationAggregationGroupPaginationServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index 95d2a71ec2..10198662a9 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -27,13 +27,13 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class ReportEventRestServlet(RestServlet): - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/rooms/(?P<room_id>[^/]*)/report/(?P<event_id>[^/]*)$" ) diff --git a/synapse/rest/client/v2_alpha/room_keys.py b/synapse/rest/client/v2_alpha/room_keys.py index 220a0de30b..87779645f9 100644 --- a/synapse/rest/client/v2_alpha/room_keys.py +++ b/synapse/rest/client/v2_alpha/room_keys.py @@ -24,13 +24,13 @@ from synapse.http.servlet import ( parse_string, ) -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class RoomKeysServlet(RestServlet): - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$" ) @@ -256,7 +256,7 @@ class RoomKeysServlet(RestServlet): class RoomKeysNewVersionServlet(RestServlet): - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/room_keys/version$" ) @@ -314,7 +314,7 @@ class RoomKeysNewVersionServlet(RestServlet): class RoomKeysVersionServlet(RestServlet): - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/room_keys/version(/(?P<version>[^/]+))?$" ) diff --git a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py index 62b8de71fa..c621a90fba 100644 --- a/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py +++ b/synapse/rest/client/v2_alpha/room_upgrade_rest_servlet.py @@ -25,7 +25,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, ) -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -47,7 +47,7 @@ class RoomUpgradeRestServlet(RestServlet): Args: hs (synapse.server.HomeServer): """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( # /rooms/$roomid/upgrade "/rooms/(?P<room_id>[^/]*)/upgrade$", ) diff --git a/synapse/rest/client/v2_alpha/sendtodevice.py b/synapse/rest/client/v2_alpha/sendtodevice.py index 21e9cef2d0..120a713361 100644 --- a/synapse/rest/client/v2_alpha/sendtodevice.py +++ b/synapse/rest/client/v2_alpha/sendtodevice.py @@ -21,13 +21,13 @@ from synapse.http import servlet from synapse.http.servlet import parse_json_object_from_request from synapse.rest.client.transactions import HttpTransactionCache -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class SendToDeviceRestServlet(servlet.RestServlet): - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/sendToDevice/(?P<message_type>[^/]*)/(?P<txn_id>[^/]*)$", ) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index c701e534e7..148fc6c985 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -32,7 +32,7 @@ from synapse.handlers.sync import SyncConfig from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.types import StreamToken -from ._base import client_v2_patterns, set_timeline_upper_limit +from ._base import client_patterns, set_timeline_upper_limit logger = logging.getLogger(__name__) @@ -73,7 +73,7 @@ class SyncRestServlet(RestServlet): } """ - PATTERNS = client_v2_patterns("/sync$") + PATTERNS = client_patterns("/sync$") ALLOWED_PRESENCE = set(["online", "offline", "unavailable"]) def __init__(self, hs): @@ -358,6 +358,9 @@ class SyncRestServlet(RestServlet): def serialize(events): return self._event_serializer.serialize_events( events, time_now=time_now, + # We don't bundle "live" events, as otherwise clients + # will end up double counting annotations. + bundle_aggregations=False, token_id=token_id, event_format=event_formatter, only_event_fields=only_fields, diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 4fea614e95..ebff7cff45 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -20,7 +20,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) @@ -29,7 +29,7 @@ class TagListServlet(RestServlet): """ GET /user/{user_id}/rooms/{room_id}/tags HTTP/1.1 """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags" ) @@ -54,7 +54,7 @@ class TagServlet(RestServlet): PUT /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 DELETE /user/{user_id}/rooms/{room_id}/tags/{tag} HTTP/1.1 """ - PATTERNS = client_v2_patterns( + PATTERNS = client_patterns( "/user/(?P<user_id>[^/]*)/rooms/(?P<room_id>[^/]*)/tags/(?P<tag>[^/]*)" ) diff --git a/synapse/rest/client/v2_alpha/thirdparty.py b/synapse/rest/client/v2_alpha/thirdparty.py index b9b5d07677..e7a987466a 100644 --- a/synapse/rest/client/v2_alpha/thirdparty.py +++ b/synapse/rest/client/v2_alpha/thirdparty.py @@ -21,13 +21,13 @@ from twisted.internet import defer from synapse.api.constants import ThirdPartyEntityKind from synapse.http.servlet import RestServlet -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class ThirdPartyProtocolsServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocols") + PATTERNS = client_patterns("/thirdparty/protocols") def __init__(self, hs): super(ThirdPartyProtocolsServlet, self).__init__() @@ -44,7 +44,7 @@ class ThirdPartyProtocolsServlet(RestServlet): class ThirdPartyProtocolServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$") + PATTERNS = client_patterns("/thirdparty/protocol/(?P<protocol>[^/]+)$") def __init__(self, hs): super(ThirdPartyProtocolServlet, self).__init__() @@ -66,7 +66,7 @@ class ThirdPartyProtocolServlet(RestServlet): class ThirdPartyUserServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$") + PATTERNS = client_patterns("/thirdparty/user(/(?P<protocol>[^/]+))?$") def __init__(self, hs): super(ThirdPartyUserServlet, self).__init__() @@ -89,7 +89,7 @@ class ThirdPartyUserServlet(RestServlet): class ThirdPartyLocationServlet(RestServlet): - PATTERNS = client_v2_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$") + PATTERNS = client_patterns("/thirdparty/location(/(?P<protocol>[^/]+))?$") def __init__(self, hs): super(ThirdPartyLocationServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 6e76b9e9c2..6c366142e1 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -18,7 +18,7 @@ from twisted.internet import defer from synapse.api.errors import AuthError from synapse.http.servlet import RestServlet -from ._base import client_v2_patterns +from ._base import client_patterns class TokenRefreshRestServlet(RestServlet): @@ -26,7 +26,7 @@ class TokenRefreshRestServlet(RestServlet): Exchanges refresh tokens for a pair of an access token and a new refresh token. """ - PATTERNS = client_v2_patterns("/tokenrefresh") + PATTERNS = client_patterns("/tokenrefresh") def __init__(self, hs): super(TokenRefreshRestServlet, self).__init__() diff --git a/synapse/rest/client/v2_alpha/user_directory.py b/synapse/rest/client/v2_alpha/user_directory.py index 36b02de37f..69e4efc47a 100644 --- a/synapse/rest/client/v2_alpha/user_directory.py +++ b/synapse/rest/client/v2_alpha/user_directory.py @@ -20,13 +20,13 @@ from twisted.internet import defer from synapse.api.errors import SynapseError from synapse.http.servlet import RestServlet, parse_json_object_from_request -from ._base import client_v2_patterns +from ._base import client_patterns logger = logging.getLogger(__name__) class UserDirectorySearchRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/user_directory/search$") + PATTERNS = client_patterns("/user_directory/search$") def __init__(self, hs): """ diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 27e7cbf3cc..babbf6a23c 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -39,6 +39,7 @@ class VersionsRestServlet(RestServlet): "r0.2.0", "r0.3.0", "r0.4.0", + "r0.5.0", ], # as per MSC1497: "unstable_features": { diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index eb8782aa6e..8a730bbc35 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -20,7 +20,7 @@ from twisted.web.resource import Resource from twisted.web.server import NOT_DONE_YET from synapse.api.errors import Codes, SynapseError -from synapse.crypto.keyring import KeyLookupError +from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler from synapse.http.servlet import parse_integer, parse_json_object_from_request @@ -89,7 +89,7 @@ class RemoteKey(Resource): isLeaf = True def __init__(self, hs): - self.keyring = hs.get_keyring() + self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() self.federation_domain_whitelist = hs.config.federation_domain_whitelist @@ -215,15 +215,7 @@ class RemoteKey(Resource): json_results.add(bytes(result["key_json"])) if cache_misses and query_remote_on_cache_miss: - for server_name, key_ids in cache_misses.items(): - try: - yield self.keyring.get_server_verify_key_v2_direct( - server_name, key_ids - ) - except KeyLookupError as e: - logger.info("Failed to fetch key: %s", e) - except Exception: - logger.exception("Failed to get key for %r", server_name) + yield self.fetcher.get_keys(cache_misses) yield self.query_keys( request, query, query_remote_on_cache_miss=False ) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index bdffa97805..8569677355 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -444,6 +444,9 @@ class MediaRepository(object): ) return + if thumbnailer.transpose_method is not None: + m_width, m_height = thumbnailer.transpose() + if t_method == "crop": t_byte_source = thumbnailer.crop(t_width, t_height, t_type) elif t_method == "scale": @@ -578,6 +581,12 @@ class MediaRepository(object): ) return + if thumbnailer.transpose_method is not None: + m_width, m_height = yield logcontext.defer_to_thread( + self.hs.get_reactor(), + thumbnailer.transpose + ) + # We deduplicate the thumbnail sizes by ignoring the cropped versions if # they have the same dimensions of a scaled one. thumbnails = {} diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 5305e9175f..35a750923b 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -56,8 +56,8 @@ class ThumbnailResource(Resource): def _async_render_GET(self, request): set_cors_headers(request) server_name, media_id, _ = parse_media_id(request) - width = parse_integer(request, "width") - height = parse_integer(request, "height") + width = parse_integer(request, "width", required=True) + height = parse_integer(request, "height", required=True) method = parse_string(request, "method", "scale") m_type = parse_string(request, "type", "image/png") diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index a4b26c2587..3efd0d80fc 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -20,6 +20,17 @@ import PIL.Image as Image logger = logging.getLogger(__name__) +EXIF_ORIENTATION_TAG = 0x0112 +EXIF_TRANSPOSE_MAPPINGS = { + 2: Image.FLIP_LEFT_RIGHT, + 3: Image.ROTATE_180, + 4: Image.FLIP_TOP_BOTTOM, + 5: Image.TRANSPOSE, + 6: Image.ROTATE_270, + 7: Image.TRANSVERSE, + 8: Image.ROTATE_90 +} + class Thumbnailer(object): @@ -31,6 +42,30 @@ class Thumbnailer(object): def __init__(self, input_path): self.image = Image.open(input_path) self.width, self.height = self.image.size + self.transpose_method = None + try: + # We don't use ImageOps.exif_transpose since it crashes with big EXIF + image_exif = self.image._getexif() + if image_exif is not None: + image_orientation = image_exif.get(EXIF_ORIENTATION_TAG) + self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation) + except Exception as e: + # A lot of parsing errors can happen when parsing EXIF + logger.info("Error parsing image EXIF information: %s", e) + + def transpose(self): + """Transpose the image using its EXIF Orientation tag + + Returns: + Tuple[int, int]: (width, height) containing the new image size in pixels. + """ + if self.transpose_method is not None: + self.image = self.image.transpose(self.transpose_method) + self.width, self.height = self.image.size + self.transpose_method = None + # We don't need EXIF any more + self.image.info["exif"] = None + return self.image.size def aspect(self, max_width, max_height): """Calculate the largest size that preserves aspect ratio which |