diff options
Diffstat (limited to 'synapse/rest')
-rw-r--r-- | synapse/rest/__init__.py | 2 | ||||
-rw-r--r-- | synapse/rest/client/v1/admin.py | 77 | ||||
-rw-r--r-- | synapse/rest/client/v1/base.py | 4 | ||||
-rw-r--r-- | synapse/rest/client/v1/login.py | 83 | ||||
-rw-r--r-- | synapse/rest/client/v1/register.py | 33 | ||||
-rw-r--r-- | synapse/rest/client/v1/room.py | 11 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/_base.py | 13 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/account.py | 114 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/devices.py | 100 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/keys.py | 102 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/register.py | 343 | ||||
-rw-r--r-- | synapse/rest/client/v2_alpha/tokenrefresh.py | 10 | ||||
-rw-r--r-- | synapse/rest/client/versions.py | 6 | ||||
-rw-r--r-- | synapse/rest/media/v0/content_repository.py | 112 | ||||
-rw-r--r-- | synapse/rest/media/v1/filepath.py | 6 | ||||
-rw-r--r-- | synapse/rest/media/v1/media_repository.py | 91 | ||||
-rw-r--r-- | synapse/rest/media/v1/preview_url_resource.py | 87 |
17 files changed, 829 insertions, 365 deletions
diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index 8b223e032b..14227f1cdb 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -46,6 +46,7 @@ from synapse.rest.client.v2_alpha import ( account_data, report_event, openid, + devices, ) from synapse.http.server import JsonResource @@ -90,3 +91,4 @@ class ClientRestResource(JsonResource): account_data.register_servlets(hs, client_resource) report_event.register_servlets(hs, client_resource) openid.register_servlets(hs, client_resource) + devices.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index aa05b3f023..b0cb31a448 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -46,5 +46,82 @@ class WhoisRestServlet(ClientV1RestServlet): defer.returnValue((200, ret)) +class PurgeMediaCacheRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/admin/purge_media_cache") + + def __init__(self, hs): + self.media_repository = hs.get_media_repository() + super(PurgeMediaCacheRestServlet, self).__init__(hs) + + @defer.inlineCallbacks + def on_POST(self, request): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + before_ts = request.args.get("before_ts", None) + if not before_ts: + raise SynapseError(400, "Missing 'before_ts' arg") + + logger.info("before_ts: %r", before_ts[0]) + + try: + before_ts = int(before_ts[0]) + except Exception: + raise SynapseError(400, "Invalid 'before_ts' arg") + + ret = yield self.media_repository.delete_old_remote_media(before_ts) + + defer.returnValue((200, ret)) + + +class PurgeHistoryRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns( + "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" + ) + + @defer.inlineCallbacks + def on_POST(self, request, room_id, event_id): + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + yield self.handlers.message_handler.purge_history(room_id, event_id) + + defer.returnValue((200, {})) + + +class DeactivateAccountRestServlet(ClientV1RestServlet): + PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") + + def __init__(self, hs): + self.store = hs.get_datastore() + super(DeactivateAccountRestServlet, self).__init__(hs) + + @defer.inlineCallbacks + def on_POST(self, request, target_user_id): + UserID.from_string(target_user_id) + requester = yield self.auth.get_user_by_req(request) + is_admin = yield self.auth.is_server_admin(requester.user) + + if not is_admin: + raise AuthError(403, "You are not a server admin") + + # FIXME: Theoretically there is a race here wherein user resets password + # using threepid. + yield self.store.user_delete_access_tokens(target_user_id) + yield self.store.user_delete_threepids(target_user_id) + yield self.store.user_set_password_hash(target_user_id, None) + + defer.returnValue((200, {})) + + def register_servlets(hs, http_server): WhoisRestServlet(hs).register(http_server) + PurgeMediaCacheRestServlet(hs).register(http_server) + DeactivateAccountRestServlet(hs).register(http_server) + PurgeHistoryRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/base.py b/synapse/rest/client/v1/base.py index 1c020b7e2c..96b49b01f2 100644 --- a/synapse/rest/client/v1/base.py +++ b/synapse/rest/client/v1/base.py @@ -52,6 +52,10 @@ class ClientV1RestServlet(RestServlet): """ def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ self.hs = hs self.handlers = hs.get_handlers() self.builder_factory = hs.get_event_builder_factory() diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 8df9d10efa..92fcae674a 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -59,6 +59,7 @@ class LoginRestServlet(ClientV1RestServlet): self.servername = hs.config.server_name self.http_client = hs.get_simple_http_client() self.auth_handler = self.hs.get_auth_handler() + self.device_handler = self.hs.get_device_handler() def on_GET(self, request): flows = [] @@ -145,15 +146,23 @@ class LoginRestServlet(ClientV1RestServlet): ).to_string() auth_handler = self.auth_handler - user_id, access_token, refresh_token = yield auth_handler.login_with_password( + user_id = yield auth_handler.validate_password_login( user_id=user_id, - password=login_submission["password"]) - + password=login_submission["password"], + ) + device_id = yield self._register_device(user_id, login_submission) + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) + ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -165,14 +174,19 @@ class LoginRestServlet(ClientV1RestServlet): user_id = ( yield auth_handler.validate_short_term_login_token_and_get_user_id(token) ) - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + device_id = yield self._register_device(user_id, login_submission) + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id( + user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { "user_id": user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, + "device_id": device_id, } defer.returnValue((200, result)) @@ -196,13 +210,15 @@ class LoginRestServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + registered_user_id = yield auth_handler.check_user_exists(user_id) + if registered_user_id: + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id + ) ) result = { - "user_id": user_id, # may have changed + "user_id": registered_user_id, # may have changed "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, @@ -245,18 +261,27 @@ class LoginRestServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if user_exists: - user_id, access_token, refresh_token = ( - yield auth_handler.get_login_tuple_for_user_id(user_id) + registered_user_id = yield auth_handler.check_user_exists(user_id) + if registered_user_id: + device_id = yield self._register_device( + registered_user_id, login_submission + ) + access_token, refresh_token = ( + yield auth_handler.get_login_tuple_for_user_id( + registered_user_id, device_id, + login_submission.get("initial_device_display_name") + ) ) result = { - "user_id": user_id, # may have changed + "user_id": registered_user_id, "access_token": access_token, "refresh_token": refresh_token, "home_server": self.hs.hostname, } else: + # TODO: we should probably check that the register isn't going + # to fonx/change our user_id before registering the device + device_id = yield self._register_device(user_id, login_submission) user_id, access_token = ( yield self.handlers.registration_handler.register(localpart=user) ) @@ -295,6 +320,26 @@ class LoginRestServlet(ClientV1RestServlet): return (user, attributes) + def _register_device(self, user_id, login_submission): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) login_submission: dictionary supplied to /login call, from + which we pull device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + device_id = login_submission.get("device_id") + initial_display_name = login_submission.get( + "initial_device_display_name") + return self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + class SAML2RestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/login/saml2", releases=()) @@ -414,13 +459,13 @@ class CasTicketServlet(ClientV1RestServlet): user_id = UserID.create(user, self.hs.hostname).to_string() auth_handler = self.auth_handler - user_exists = yield auth_handler.does_user_exist(user_id) - if not user_exists: - user_id, _ = ( + registered_user_id = yield auth_handler.check_user_exists(user_id) + if not registered_user_id: + registered_user_id, _ = ( yield self.handlers.registration_handler.register(localpart=user) ) - login_token = auth_handler.generate_short_term_login_token(user_id) + login_token = auth_handler.generate_short_term_login_token(registered_user_id) redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, login_token) request.redirect(redirect_url) diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py index e3f4fbb0bb..2383b9df86 100644 --- a/synapse/rest/client/v1/register.py +++ b/synapse/rest/client/v1/register.py @@ -52,6 +52,10 @@ class RegisterRestServlet(ClientV1RestServlet): PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False) def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRestServlet, self).__init__(hs) # sessions are stored as: # self.sessions = { @@ -60,6 +64,7 @@ class RegisterRestServlet(ClientV1RestServlet): # TODO: persistent storage self.sessions = {} self.enable_registration = hs.config.enable_registration + self.auth_handler = hs.get_auth_handler() def on_GET(self, request): if self.hs.config.enable_registration_captcha: @@ -299,9 +304,10 @@ class RegisterRestServlet(ClientV1RestServlet): user_localpart = register_json["user"].encode("utf-8") handler = self.handlers.registration_handler - (user_id, token) = yield handler.appservice_register( + user_id = yield handler.appservice_register( user_localpart, as_token ) + token = yield self.auth_handler.issue_access_token(user_id) self._remove_session(session) defer.returnValue({ "user_id": user_id, @@ -324,6 +330,14 @@ class RegisterRestServlet(ClientV1RestServlet): raise SynapseError(400, "Shared secret registration is not enabled") user = register_json["user"].encode("utf-8") + password = register_json["password"].encode("utf-8") + admin = register_json.get("admin", None) + + # Its important to check as we use null bytes as HMAC field separators + if "\x00" in user: + raise SynapseError(400, "Invalid user") + if "\x00" in password: + raise SynapseError(400, "Invalid password") # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface @@ -331,17 +345,21 @@ class RegisterRestServlet(ClientV1RestServlet): want_mac = hmac.new( key=self.hs.config.registration_shared_secret, - msg=user, digestmod=sha1, - ).hexdigest() - - password = register_json["password"].encode("utf-8") + ) + want_mac.update(user) + want_mac.update("\x00") + want_mac.update(password) + want_mac.update("\x00") + want_mac.update("admin" if admin else "notadmin") + want_mac = want_mac.hexdigest() if compare_digest(want_mac, got_mac): handler = self.handlers.registration_handler user_id, token = yield handler.register( localpart=user, password=password, + admin=bool(admin), ) self._remove_session(session) defer.returnValue({ @@ -410,12 +428,15 @@ class CreateUserRestServlet(ClientV1RestServlet): raise SynapseError(400, "Failed to parse 'duration_seconds'") if duration_seconds > self.direct_user_creation_max_duration: duration_seconds = self.direct_user_creation_max_duration + password_hash = user_json["password_hash"].encode("utf-8") \ + if user_json.get("password_hash") else None handler = self.handlers.registration_handler user_id, token = yield handler.get_or_create_user( localpart=localpart, displayname=displayname, - duration_seconds=duration_seconds + duration_in_ms=(duration_seconds * 1000), + password_hash=password_hash ) defer.returnValue({ diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 86fbe2747d..866a1e9120 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -20,12 +20,14 @@ from .base import ClientV1RestServlet, client_path_patterns from synapse.api.errors import SynapseError, Codes, AuthError from synapse.streams.config import PaginationConfig from synapse.api.constants import EventTypes, Membership +from synapse.api.filtering import Filter from synapse.types import UserID, RoomID, RoomAlias from synapse.events.utils import serialize_event from synapse.http.servlet import parse_json_object_from_request import logging import urllib +import ujson as json logger = logging.getLogger(__name__) @@ -327,12 +329,19 @@ class RoomMessageListRestServlet(ClientV1RestServlet): request, default_limit=10, ) as_client_event = "raw" not in request.args + filter_bytes = request.args.get("filter", None) + if filter_bytes: + filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8") + event_filter = Filter(json.loads(filter_json)) + else: + event_filter = None handler = self.handlers.message_handler msgs = yield handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, - as_client_event=as_client_event + as_client_event=as_client_event, + event_filter=event_filter, ) defer.returnValue((200, msgs)) diff --git a/synapse/rest/client/v2_alpha/_base.py b/synapse/rest/client/v2_alpha/_base.py index b6faa2b0e6..20e765f48f 100644 --- a/synapse/rest/client/v2_alpha/_base.py +++ b/synapse/rest/client/v2_alpha/_base.py @@ -25,7 +25,9 @@ import logging logger = logging.getLogger(__name__) -def client_v2_patterns(path_regex, releases=(0,)): +def client_v2_patterns(path_regex, releases=(0,), + v2_alpha=True, + unstable=True): """Creates a regex compiled client path with the correct client path prefix. @@ -35,9 +37,12 @@ def client_v2_patterns(path_regex, releases=(0,)): Returns: SRE_Pattern """ - patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)] - unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") - patterns.append(re.compile("^" + unstable_prefix + path_regex)) + patterns = [] + if v2_alpha: + patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)) + if unstable: + unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") + patterns.append(re.compile("^" + unstable_prefix + path_regex)) for release in releases: new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) patterns.append(re.compile("^" + new_prefix + path_regex)) diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 9a84873a5f..eb49ad62e9 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -28,8 +28,40 @@ import logging logger = logging.getLogger(__name__) +class PasswordRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/password/email/requestToken$") + + def __init__(self, hs): + super(PasswordRequestTokenRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if absent: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is None: + raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + class PasswordRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/password") + PATTERNS = client_v2_patterns("/account/password$") def __init__(self, hs): super(PasswordRestServlet, self).__init__() @@ -89,8 +121,83 @@ class PasswordRestServlet(RestServlet): return 200, {} +class DeactivateAccountRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/deactivate$") + + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() + super(DeactivateAccountRestServlet, self).__init__() + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + authed, result, params, _ = yield self.auth_handler.check_auth([ + [LoginType.PASSWORD], + ], body, self.hs.get_ip_from_request(request)) + + if not authed: + defer.returnValue((401, result)) + + user_id = None + requester = None + + if LoginType.PASSWORD in result: + # if using password, they should also be logged in + requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() + if user_id != result[LoginType.PASSWORD]: + raise LoginError(400, "", Codes.UNKNOWN) + else: + logger.error("Auth succeeded but no known type!", result.keys()) + raise SynapseError(500, "", Codes.UNKNOWN) + + # FIXME: Theoretically there is a race here wherein user resets password + # using threepid. + yield self.store.user_delete_access_tokens(user_id) + yield self.store.user_delete_threepids(user_id) + yield self.store.user_set_password_hash(user_id, None) + + defer.returnValue((200, {})) + + +class ThreepidRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") + + def __init__(self, hs): + self.hs = hs + super(ThreepidRequestTokenRestServlet, self).__init__() + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if absent: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is not None: + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + class ThreepidRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/account/3pid") + PATTERNS = client_v2_patterns("/account/3pid$") def __init__(self, hs): super(ThreepidRestServlet, self).__init__() @@ -157,5 +264,8 @@ class ThreepidRestServlet(RestServlet): def register_servlets(hs, http_server): + PasswordRequestTokenRestServlet(hs).register(http_server) PasswordRestServlet(hs).register(http_server) + DeactivateAccountRestServlet(hs).register(http_server) + ThreepidRequestTokenRestServlet(hs).register(http_server) ThreepidRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/devices.py b/synapse/rest/client/v2_alpha/devices.py new file mode 100644 index 0000000000..8fbd3d3dfc --- /dev/null +++ b/synapse/rest/client/v2_alpha/devices.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.internet import defer + +from synapse.http import servlet +from ._base import client_v2_patterns + +logger = logging.getLogger(__name__) + + +class DevicesRestServlet(servlet.RestServlet): + PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(DevicesRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request): + requester = yield self.auth.get_user_by_req(request) + devices = yield self.device_handler.get_devices_by_user( + requester.user.to_string() + ) + defer.returnValue((200, {"devices": devices})) + + +class DeviceRestServlet(servlet.RestServlet): + PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", + releases=[], v2_alpha=False) + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(DeviceRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() + + @defer.inlineCallbacks + def on_GET(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + device = yield self.device_handler.get_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, device)) + + @defer.inlineCallbacks + def on_DELETE(self, request, device_id): + # XXX: it's not completely obvious we want to expose this endpoint. + # It allows the client to delete access tokens, which feels like a + # thing which merits extra auth. But if we want to do the interactive- + # auth dance, we should really make it possible to delete more than one + # device at a time. + requester = yield self.auth.get_user_by_req(request) + yield self.device_handler.delete_device( + requester.user.to_string(), + device_id, + ) + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_PUT(self, request, device_id): + requester = yield self.auth.get_user_by_req(request) + + body = servlet.parse_json_object_from_request(request) + yield self.device_handler.update_device( + requester.user.to_string(), + device_id, + body + ) + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + DevicesRestServlet(hs).register(http_server) + DeviceRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py index 89ab39491c..c5ff16adf3 100644 --- a/synapse/rest/client/v2_alpha/keys.py +++ b/synapse/rest/client/v2_alpha/keys.py @@ -13,24 +13,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +import simplejson as json +from canonicaljson import encode_canonical_json from twisted.internet import defer +import synapse.api.errors +import synapse.server +import synapse.types from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.types import UserID - -from canonicaljson import encode_canonical_json - from ._base import client_v2_patterns -import logging -import simplejson as json - logger = logging.getLogger(__name__) class KeyUploadServlet(RestServlet): """ - POST /keys/upload/<device_id> HTTP/1.1 + POST /keys/upload HTTP/1.1 Content-Type: application/json { @@ -53,23 +54,45 @@ class KeyUploadServlet(RestServlet): }, } """ - PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=()) + PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$", + releases=()) def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(KeyUploadServlet, self).__init__() self.store = hs.get_datastore() self.clock = hs.get_clock() self.auth = hs.get_auth() + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def on_POST(self, request, device_id): requester = yield self.auth.get_user_by_req(request) + user_id = requester.user.to_string() - # TODO: Check that the device_id matches that in the authentication - # or derive the device_id from the authentication instead. body = parse_json_object_from_request(request) + if device_id is not None: + # passing the device_id here is deprecated; however, we allow it + # for now for compatibility with older clients. + if (requester.device_id is not None and + device_id != requester.device_id): + logger.warning("Client uploading keys for a different device " + "(logged in as %s, uploading for %s)", + requester.device_id, device_id) + else: + device_id = requester.device_id + + if device_id is None: + raise synapse.api.errors.SynapseError( + 400, + "To upload keys, you must pass device_id when authenticating" + ) + time_now = self.clock.time_msec() # TODO: Validate the JSON to make sure it has the right keys. @@ -102,13 +125,12 @@ class KeyUploadServlet(RestServlet): user_id, device_id, time_now, key_list ) - result = yield self.store.count_e2e_one_time_keys(user_id, device_id) - defer.returnValue((200, {"one_time_key_counts": result})) - - @defer.inlineCallbacks - def on_GET(self, request, device_id): - requester = yield self.auth.get_user_by_req(request) - user_id = requester.user.to_string() + # the device should have been registered already, but it may have been + # deleted due to a race with a DELETE request. Or we may be using an + # old access_token without an associated device_id. Either way, we + # need to double-check the device is registered to avoid ending up with + # keys without a corresponding device. + self.device_handler.check_device_registered(user_id, device_id) result = yield self.store.count_e2e_one_time_keys(user_id, device_id) defer.returnValue((200, {"one_time_key_counts": result})) @@ -162,17 +184,19 @@ class KeyQueryServlet(RestServlet): ) def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): + """ super(KeyQueryServlet, self).__init__() - self.store = hs.get_datastore() self.auth = hs.get_auth() - self.federation = hs.get_replication_layer() - self.is_mine = hs.is_mine + self.e2e_keys_handler = hs.get_e2e_keys_handler() @defer.inlineCallbacks def on_POST(self, request, user_id, device_id): yield self.auth.get_user_by_req(request) body = parse_json_object_from_request(request) - result = yield self.handle_request(body) + result = yield self.e2e_keys_handler.query_devices(body) defer.returnValue(result) @defer.inlineCallbacks @@ -181,45 +205,11 @@ class KeyQueryServlet(RestServlet): auth_user_id = requester.user.to_string() user_id = user_id if user_id else auth_user_id device_ids = [device_id] if device_id else [] - result = yield self.handle_request( + result = yield self.e2e_keys_handler.query_devices( {"device_keys": {user_id: device_ids}} ) defer.returnValue(result) - @defer.inlineCallbacks - def handle_request(self, body): - local_query = [] - remote_queries = {} - for user_id, device_ids in body.get("device_keys", {}).items(): - user = UserID.from_string(user_id) - if self.is_mine(user): - if not device_ids: - local_query.append((user_id, None)) - else: - for device_id in device_ids: - local_query.append((user_id, device_id)) - else: - remote_queries.setdefault(user.domain, {})[user_id] = list( - device_ids - ) - results = yield self.store.get_e2e_device_keys(local_query) - - json_result = {} - for user_id, device_keys in results.items(): - for device_id, json_bytes in device_keys.items(): - json_result.setdefault(user_id, {})[device_id] = json.loads( - json_bytes - ) - - for destination, device_keys in remote_queries.items(): - remote_result = yield self.federation.query_client_keys( - destination, {"device_keys": device_keys} - ) - for user_id, keys in remote_result["device_keys"].items(): - if user_id in device_keys: - json_result[user_id] = keys - defer.returnValue((200, {"device_keys": json_result})) - class OneTimeKeyServlet(RestServlet): """ diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 2088c316d1..943f5676a3 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -41,17 +41,59 @@ else: logger = logging.getLogger(__name__) +class RegisterRequestTokenRestServlet(RestServlet): + PATTERNS = client_v2_patterns("/register/email/requestToken$") + + def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ + super(RegisterRequestTokenRestServlet, self).__init__() + self.hs = hs + self.identity_handler = hs.get_handlers().identity_handler + + @defer.inlineCallbacks + def on_POST(self, request): + body = parse_json_object_from_request(request) + + required = ['id_server', 'client_secret', 'email', 'send_attempt'] + absent = [] + for k in required: + if k not in body: + absent.append(k) + + if len(absent) > 0: + raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + + existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( + 'email', body['email'] + ) + + if existingUid is not None: + raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + + ret = yield self.identity_handler.requestEmailToken(**body) + defer.returnValue((200, ret)) + + class RegisterRestServlet(RestServlet): - PATTERNS = client_v2_patterns("/register") + PATTERNS = client_v2_patterns("/register$") def __init__(self, hs): + """ + Args: + hs (synapse.server.HomeServer): server + """ super(RegisterRestServlet, self).__init__() + self.hs = hs self.auth = hs.get_auth() self.store = hs.get_datastore() self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_handlers().registration_handler self.identity_handler = hs.get_handlers().identity_handler + self.device_handler = hs.get_device_handler() @defer.inlineCallbacks def on_POST(self, request): @@ -70,10 +112,6 @@ class RegisterRestServlet(RestServlet): "Do not understand membership kind: %s" % (kind,) ) - if '/register/email/requestToken' in request.path: - ret = yield self.onEmailTokenRequest(request) - defer.returnValue(ret) - body = parse_json_object_from_request(request) # we do basic sanity checks here because the auth layer will store these @@ -104,11 +142,12 @@ class RegisterRestServlet(RestServlet): # Set the desired user according to the AS API (which uses the # 'user' key not 'username'). Since this is a new addition, we'll # fallback to 'username' if they gave one. - if isinstance(body.get("user"), basestring): - desired_username = body["user"] - result = yield self._do_appservice_registration( - desired_username, request.args["access_token"][0] - ) + desired_username = body.get("user", desired_username) + + if isinstance(desired_username, basestring): + result = yield self._do_appservice_registration( + desired_username, request.args["access_token"][0], body + ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -117,7 +156,7 @@ class RegisterRestServlet(RestServlet): # FIXME: Should we really be determining if this is shared secret # auth based purely on the 'mac' key? result = yield self._do_shared_secret_registration( - desired_username, desired_password, body["mac"] + desired_username, desired_password, body ) defer.returnValue((200, result)) # we throw for non 200 responses return @@ -157,12 +196,12 @@ class RegisterRestServlet(RestServlet): [LoginType.EMAIL_IDENTITY] ] - authed, result, params, session_id = yield self.auth_handler.check_auth( + authed, auth_result, params, session_id = yield self.auth_handler.check_auth( flows, body, self.hs.get_ip_from_request(request) ) if not authed: - defer.returnValue((401, result)) + defer.returnValue((401, auth_result)) return if registered_user_id is not None: @@ -170,106 +209,58 @@ class RegisterRestServlet(RestServlet): "Already registered user ID %r for this session", registered_user_id ) - access_token = yield self.auth_handler.issue_access_token(registered_user_id) - refresh_token = yield self.auth_handler.issue_refresh_token( - registered_user_id + # don't re-register the email address + add_email = False + else: + # NB: This may be from the auth handler and NOT from the POST + if 'password' not in params: + raise SynapseError(400, "Missing password.", + Codes.MISSING_PARAM) + + desired_username = params.get("username", None) + new_password = params.get("password", None) + guest_access_token = params.get("guest_access_token", None) + + (registered_user_id, _) = yield self.registration_handler.register( + localpart=desired_username, + password=new_password, + guest_access_token=guest_access_token, + generate_token=False, ) - defer.returnValue((200, { - "user_id": registered_user_id, - "access_token": access_token, - "home_server": self.hs.hostname, - "refresh_token": refresh_token, - })) - - # NB: This may be from the auth handler and NOT from the POST - if 'password' not in params: - raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM) - - desired_username = params.get("username", None) - new_password = params.get("password", None) - guest_access_token = params.get("guest_access_token", None) - - (user_id, token) = yield self.registration_handler.register( - localpart=desired_username, - password=new_password, - guest_access_token=guest_access_token, - ) - # remember that we've now registered that user account, and with what - # user ID (since the user may not have specified) - self.auth_handler.set_session_data( - session_id, "registered_user_id", user_id + # remember that we've now registered that user account, and with + # what user ID (since the user may not have specified) + self.auth_handler.set_session_data( + session_id, "registered_user_id", registered_user_id + ) + + add_email = True + + return_dict = yield self._create_registration_details( + registered_user_id, params ) - if result and LoginType.EMAIL_IDENTITY in result: - threepid = result[LoginType.EMAIL_IDENTITY] - - for reqd in ['medium', 'address', 'validated_at']: - if reqd not in threepid: - logger.info("Can't add incomplete 3pid") - else: - yield self.auth_handler.add_threepid( - user_id, - threepid['medium'], - threepid['address'], - threepid['validated_at'], - ) - - # And we add an email pusher for them by default, but only - # if email notifications are enabled (so people don't start - # getting mail spam where they weren't before if email - # notifs are set up on a home server) - if ( - self.hs.config.email_enable_notifs and - self.hs.config.email_notif_for_new_users - ): - # Pull the ID of the access token back out of the db - # It would really make more sense for this to be passed - # up when the access token is saved, but that's quite an - # invasive change I'd rather do separately. - user_tuple = yield self.store.get_user_by_access_token( - token - ) - - yield self.hs.get_pusherpool().add_pusher( - user_id=user_id, - access_token=user_tuple["token_id"], - kind="email", - app_id="m.email", - app_display_name="Email Notifications", - device_display_name=threepid["address"], - pushkey=threepid["address"], - lang=None, # We don't know a user's language here - data={}, - ) - - if 'bind_email' in params and params['bind_email']: - logger.info("bind_email specified: binding") - - emailThreepid = result[LoginType.EMAIL_IDENTITY] - threepid_creds = emailThreepid['threepid_creds'] - logger.debug("Binding emails %s to %s" % ( - emailThreepid, user_id - )) - yield self.identity_handler.bind_threepid(threepid_creds, user_id) - else: - logger.info("bind_email not specified: not binding email") - - result = yield self._create_registration_details(user_id, token) - defer.returnValue((200, result)) + if add_email and auth_result and LoginType.EMAIL_IDENTITY in auth_result: + threepid = auth_result[LoginType.EMAIL_IDENTITY] + yield self._register_email_threepid( + registered_user_id, threepid, return_dict["access_token"], + params.get("bind_email") + ) + + defer.returnValue((200, return_dict)) def on_OPTIONS(self, _): return 200, {} @defer.inlineCallbacks - def _do_appservice_registration(self, username, as_token): - (user_id, token) = yield self.registration_handler.appservice_register( + def _do_appservice_registration(self, username, as_token, body): + user_id = yield self.registration_handler.appservice_register( username, as_token ) - defer.returnValue((yield self._create_registration_details(user_id, token))) + defer.returnValue((yield self._create_registration_details(user_id, body))) @defer.inlineCallbacks - def _do_shared_secret_registration(self, username, password, mac): + def _do_shared_secret_registration(self, username, password, body): if not self.hs.config.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") @@ -277,7 +268,7 @@ class RegisterRestServlet(RestServlet): # str() because otherwise hmac complains that 'unicode' does not # have the buffer interface - got_mac = str(mac) + got_mac = str(body["mac"]) want_mac = hmac.new( key=self.hs.config.registration_shared_secret, @@ -290,43 +281,132 @@ class RegisterRestServlet(RestServlet): 403, "HMAC incorrect", ) - (user_id, token) = yield self.registration_handler.register( - localpart=username, password=password + (user_id, _) = yield self.registration_handler.register( + localpart=username, password=password, generate_token=False, ) - defer.returnValue((yield self._create_registration_details(user_id, token))) - @defer.inlineCallbacks - def _create_registration_details(self, user_id, token): - refresh_token = yield self.auth_handler.issue_refresh_token(user_id) - defer.returnValue({ - "user_id": user_id, - "access_token": token, - "home_server": self.hs.hostname, - "refresh_token": refresh_token, - }) + result = yield self._create_registration_details(user_id, body) + defer.returnValue(result) @defer.inlineCallbacks - def onEmailTokenRequest(self, request): - body = parse_json_object_from_request(request) + def _register_email_threepid(self, user_id, threepid, token, bind_email): + """Add an email address as a 3pid identifier + + Also adds an email pusher for the email address, if configured in the + HS config + + Also optionally binds emails to the given user_id on the identity server + + Args: + user_id (str): id of user + threepid (object): m.login.email.identity auth response + token (str): access_token for the user + bind_email (bool): true if the client requested the email to be + bound at the identity server + Returns: + defer.Deferred: + """ + reqd = ('medium', 'address', 'validated_at') + if any(x not in threepid for x in reqd): + logger.info("Can't add incomplete 3pid") + defer.returnValue() + + yield self.auth_handler.add_threepid( + user_id, + threepid['medium'], + threepid['address'], + threepid['validated_at'], + ) - required = ['id_server', 'client_secret', 'email', 'send_attempt'] - absent = [] - for k in required: - if k not in body: - absent.append(k) + # And we add an email pusher for them by default, but only + # if email notifications are enabled (so people don't start + # getting mail spam where they weren't before if email + # notifs are set up on a home server) + if (self.hs.config.email_enable_notifs and + self.hs.config.email_notif_for_new_users): + # Pull the ID of the access token back out of the db + # It would really make more sense for this to be passed + # up when the access token is saved, but that's quite an + # invasive change I'd rather do separately. + user_tuple = yield self.store.get_user_by_access_token( + token + ) + token_id = user_tuple["token_id"] + + yield self.hs.get_pusherpool().add_pusher( + user_id=user_id, + access_token=token_id, + kind="email", + app_id="m.email", + app_display_name="Email Notifications", + device_display_name=threepid["address"], + pushkey=threepid["address"], + lang=None, # We don't know a user's language here + data={}, + ) - if len(absent) > 0: - raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) + if bind_email: + logger.info("bind_email specified: binding") + logger.debug("Binding emails %s to %s" % ( + threepid, user_id + )) + yield self.identity_handler.bind_threepid( + threepid['threepid_creds'], user_id + ) + else: + logger.info("bind_email not specified: not binding email") - existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( - 'email', body['email'] + @defer.inlineCallbacks + def _create_registration_details(self, user_id, params): + """Complete registration of newly-registered user + + Allocates device_id if one was not given; also creates access_token + and refresh_token. + + Args: + (str) user_id: full canonical @user:id + (object) params: registration parameters, from which we pull + device_id and initial_device_name + Returns: + defer.Deferred: (object) dictionary for response from /register + """ + device_id = yield self._register_device(user_id, params) + + access_token, refresh_token = ( + yield self.auth_handler.get_login_tuple_for_user_id( + user_id, device_id=device_id, + initial_display_name=params.get("initial_device_display_name") + ) ) - if existingUid is not None: - raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) + defer.returnValue({ + "user_id": user_id, + "access_token": access_token, + "home_server": self.hs.hostname, + "refresh_token": refresh_token, + "device_id": device_id, + }) - ret = yield self.identity_handler.requestEmailToken(**body) - defer.returnValue((200, ret)) + def _register_device(self, user_id, params): + """Register a device for a user. + + This is called after the user's credentials have been validated, but + before the access token has been issued. + + Args: + (str) user_id: full canonical @user:id + (object) params: registration parameters, from which we pull + device_id and initial_device_name + Returns: + defer.Deferred: (str) device_id + """ + # register the user's device + device_id = params.get("device_id") + initial_display_name = params.get("initial_device_display_name") + device_id = self.device_handler.check_device_registered( + user_id, device_id, initial_display_name + ) + return device_id @defer.inlineCallbacks def _do_guest_registration(self): @@ -336,7 +416,11 @@ class RegisterRestServlet(RestServlet): generate_token=False, make_guest=True ) - access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"]) + access_token = self.auth_handler.generate_access_token( + user_id, ["guest = true"] + ) + # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok + # so long as we don't return a refresh_token here. defer.returnValue((200, { "user_id": user_id, "access_token": access_token, @@ -345,4 +429,5 @@ class RegisterRestServlet(RestServlet): def register_servlets(hs, http_server): + RegisterRequestTokenRestServlet(hs).register(http_server) RegisterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/tokenrefresh.py b/synapse/rest/client/v2_alpha/tokenrefresh.py index 8270e8787f..0d312c91d4 100644 --- a/synapse/rest/client/v2_alpha/tokenrefresh.py +++ b/synapse/rest/client/v2_alpha/tokenrefresh.py @@ -39,9 +39,13 @@ class TokenRefreshRestServlet(RestServlet): try: old_refresh_token = body["refresh_token"] auth_handler = self.hs.get_auth_handler() - (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( - old_refresh_token, auth_handler.generate_refresh_token) - new_access_token = yield auth_handler.issue_access_token(user_id) + refresh_result = yield self.store.exchange_refresh_token( + old_refresh_token, auth_handler.generate_refresh_token + ) + (user_id, new_refresh_token, device_id) = refresh_result + new_access_token = yield auth_handler.issue_access_token( + user_id, device_id + ) defer.returnValue((200, { "access_token": new_access_token, "refresh_token": new_refresh_token, diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index ca5468c402..e984ea47db 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -26,7 +26,11 @@ class VersionsRestServlet(RestServlet): def on_GET(self, request): return (200, { - "versions": ["r0.0.1"] + "versions": [ + "r0.0.1", + "r0.1.0", + "r0.2.0", + ] }) diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index d9fc045fc6..956bd5da75 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -15,14 +15,12 @@ from synapse.http.server import respond_with_json_bytes, finish_request -from synapse.util.stringutils import random_string from synapse.api.errors import ( - cs_exception, SynapseError, CodeMessageException, Codes, cs_error + Codes, cs_error ) from twisted.protocols.basic import FileSender from twisted.web import server, resource -from twisted.internet import defer import base64 import simplejson as json @@ -50,64 +48,10 @@ class ContentRepoResource(resource.Resource): """ isLeaf = True - def __init__(self, hs, directory, auth, external_addr): + def __init__(self, hs, directory): resource.Resource.__init__(self) self.hs = hs self.directory = directory - self.auth = auth - self.external_addr = external_addr.rstrip('/') - self.max_upload_size = hs.config.max_upload_size - - if not os.path.isdir(self.directory): - os.mkdir(self.directory) - logger.info("ContentRepoResource : Created %s directory.", - self.directory) - - @defer.inlineCallbacks - def map_request_to_name(self, request): - # auth the user - requester = yield self.auth.get_user_by_req(request) - - # namespace all file uploads on the user - prefix = base64.urlsafe_b64encode( - requester.user.to_string() - ).replace('=', '') - - # use a random string for the main portion - main_part = random_string(24) - - # suffix with a file extension if we can make one. This is nice to - # provide a hint to clients on the file information. We will also reuse - # this info to spit back the content type to the client. - suffix = "" - if request.requestHeaders.hasHeader("Content-Type"): - content_type = request.requestHeaders.getRawHeaders( - "Content-Type")[0] - suffix = "." + base64.urlsafe_b64encode(content_type) - if (content_type.split("/")[0].lower() in - ["image", "video", "audio"]): - file_ext = content_type.split("/")[-1] - # be a little paranoid and only allow a-z - file_ext = re.sub("[^a-z]", "", file_ext) - suffix += "." + file_ext - - file_name = prefix + main_part + suffix - file_path = os.path.join(self.directory, file_name) - logger.info("User %s is uploading a file to path %s", - request.user.user_id.to_string(), - file_path) - - # keep trying to make a non-clashing file, with a sensible max attempts - attempts = 0 - while os.path.exists(file_path): - main_part = random_string(24) - file_name = prefix + main_part + suffix - file_path = os.path.join(self.directory, file_name) - attempts += 1 - if attempts > 25: # really? Really? - raise SynapseError(500, "Unable to create file.") - - defer.returnValue(file_path) def render_GET(self, request): # no auth here on purpose, to allow anyone to view, even across home @@ -155,58 +99,6 @@ class ContentRepoResource(resource.Resource): return server.NOT_DONE_YET - def render_POST(self, request): - self._async_render(request) - return server.NOT_DONE_YET - def render_OPTIONS(self, request): respond_with_json_bytes(request, 200, {}, send_cors=True) return server.NOT_DONE_YET - - @defer.inlineCallbacks - def _async_render(self, request): - try: - # TODO: The checks here are a bit late. The content will have - # already been uploaded to a tmp file at this point - content_length = request.getHeader("Content-Length") - if content_length is None: - raise SynapseError( - msg="Request must specify a Content-Length", code=400 - ) - if int(content_length) > self.max_upload_size: - raise SynapseError( - msg="Upload request body is too large", - code=413, - ) - - fname = yield self.map_request_to_name(request) - - # TODO I have a suspicious feeling this is just going to block - with open(fname, "wb") as f: - f.write(request.content.read()) - - # FIXME (erikj): These should use constants. - file_name = os.path.basename(fname) - # FIXME: we can't assume what the repo's public mounted path is - # ...plus self-signed SSL won't work to remote clients anyway - # ...and we can't assume that it's SSL anyway, as we might want to - # serve it via the non-SSL listener... - url = "%s/_matrix/content/%s" % ( - self.external_addr, file_name - ) - - respond_with_json_bytes(request, 200, - json.dumps({"content_token": url}), - send_cors=True) - - except CodeMessageException as e: - logger.exception(e) - respond_with_json_bytes(request, e.code, - json.dumps(cs_exception(e))) - except Exception as e: - logger.error("Failed to store file: %s" % e) - respond_with_json_bytes( - request, - 500, - json.dumps({"error": "Internal server error"}), - send_cors=True) diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 422ab86fb3..0137458f71 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -65,3 +65,9 @@ class MediaFilePaths(object): file_id[0:2], file_id[2:4], file_id[4:], file_name ) + + def remote_media_thumbnail_dir(self, server_name, file_id): + return os.path.join( + self.base_path, "remote_thumbnail", server_name, + file_id[0:2], file_id[2:4], file_id[4:], + ) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 2468c3ac42..692e078419 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -30,11 +30,13 @@ from synapse.api.errors import SynapseError from twisted.internet import defer, threads -from synapse.util.async import ObservableDeferred +from synapse.util.async import Linearizer from synapse.util.stringutils import is_ascii from synapse.util.logcontext import preserve_context_over_fn import os +import errno +import shutil import cgi import logging @@ -43,8 +45,11 @@ import urlparse logger = logging.getLogger(__name__) +UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000 + + class MediaRepository(object): - def __init__(self, hs, filepaths): + def __init__(self, hs): self.auth = hs.get_auth() self.client = MatrixFederationHttpClient(hs) self.clock = hs.get_clock() @@ -52,11 +57,28 @@ class MediaRepository(object): self.store = hs.get_datastore() self.max_upload_size = hs.config.max_upload_size self.max_image_pixels = hs.config.max_image_pixels - self.filepaths = filepaths - self.downloads = {} + self.filepaths = MediaFilePaths(hs.config.media_store_path) self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.thumbnail_requirements = hs.config.thumbnail_requirements + self.remote_media_linearizer = Linearizer() + + self.recently_accessed_remotes = set() + + self.clock.looping_call( + self._update_recently_accessed_remotes, + UPDATE_RECENTLY_ACCESSED_REMOTES_TS + ) + + @defer.inlineCallbacks + def _update_recently_accessed_remotes(self): + media = self.recently_accessed_remotes + self.recently_accessed_remotes = set() + + yield self.store.update_cached_last_access_time( + media, self.clock.time_msec() + ) + @staticmethod def _makedirs(filepath): dirname = os.path.dirname(filepath) @@ -93,22 +115,12 @@ class MediaRepository(object): defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) + @defer.inlineCallbacks def get_remote_media(self, server_name, media_id): key = (server_name, media_id) - download = self.downloads.get(key) - if download is None: - download = self._get_remote_media_impl(server_name, media_id) - download = ObservableDeferred( - download, - consumeErrors=True - ) - self.downloads[key] = download - - @download.addBoth - def callback(media_info): - del self.downloads[key] - return media_info - return download.observe() + with (yield self.remote_media_linearizer.queue(key)): + media_info = yield self._get_remote_media_impl(server_name, media_id) + defer.returnValue(media_info) @defer.inlineCallbacks def _get_remote_media_impl(self, server_name, media_id): @@ -119,6 +131,11 @@ class MediaRepository(object): media_info = yield self._download_remote_file( server_name, media_id ) + else: + self.recently_accessed_remotes.add((server_name, media_id)) + yield self.store.update_cached_last_access_time( + [(server_name, media_id)], self.clock.time_msec() + ) defer.returnValue(media_info) @defer.inlineCallbacks @@ -416,6 +433,41 @@ class MediaRepository(object): "height": m_height, }) + @defer.inlineCallbacks + def delete_old_remote_media(self, before_ts): + old_media = yield self.store.get_remote_media_before(before_ts) + + deleted = 0 + + for media in old_media: + origin = media["media_origin"] + media_id = media["media_id"] + file_id = media["filesystem_id"] + key = (origin, media_id) + + logger.info("Deleting: %r", key) + + with (yield self.remote_media_linearizer.queue(key)): + full_path = self.filepaths.remote_media_filepath(origin, file_id) + try: + os.remove(full_path) + except OSError as e: + logger.warn("Failed to remove file: %r", full_path) + if e.errno == errno.ENOENT: + pass + else: + continue + + thumbnail_dir = self.filepaths.remote_media_thumbnail_dir( + origin, file_id + ) + shutil.rmtree(thumbnail_dir, ignore_errors=True) + + yield self.store.delete_remote_media(origin, media_id) + deleted += 1 + + defer.returnValue({"deleted": deleted}) + class MediaRepositoryResource(Resource): """File uploading and downloading. @@ -464,9 +516,8 @@ class MediaRepositoryResource(Resource): def __init__(self, hs): Resource.__init__(self) - filepaths = MediaFilePaths(hs.config.media_store_path) - media_repo = MediaRepository(hs, filepaths) + media_repo = hs.get_media_repository() self.putChild("upload", UploadResource(hs, media_repo)) self.putChild("download", DownloadResource(hs, media_repo)) diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 74c64f1371..bdd0e60c5b 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -29,6 +29,8 @@ from synapse.http.server import ( from synapse.util.async import ObservableDeferred from synapse.util.stringutils import is_ascii +from copy import deepcopy + import os import re import fnmatch @@ -329,20 +331,24 @@ class PreviewUrlResource(Resource): # ...or if they are within a <script/> or <style/> tag. # This is a very very very coarse approximation to a plain text # render of the page. - text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | " - "ancestor::aside | ancestor::footer | " - "ancestor::script | ancestor::style)]" + - "[ancestor::body]") - text = '' - for text_node in text_nodes: - if len(text) < 500: - text += text_node + ' ' - else: - break - text = re.sub(r'[\t ]+', ' ', text) - text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text) - text = text.strip()[:500] - og['og:description'] = text if text else None + + # We don't just use XPATH here as that is slow on some machines. + + # We clone `tree` as we modify it. + cloned_tree = deepcopy(tree.find("body")) + + TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",) + for el in cloned_tree.iter(TAGS_TO_REMOVE): + el.getparent().remove(el) + + # Split all the text nodes into paragraphs (by splitting on new + # lines) + text_nodes = ( + re.sub(r'\s+', '\n', el.text).strip() + for el in cloned_tree.iter() + if el.text and isinstance(el.tag, basestring) # Removes comments + ) + og['og:description'] = summarize_paragraphs(text_nodes) # TODO: delete the url downloads to stop diskfilling, # as we only ever cared about its OG @@ -450,3 +456,56 @@ class PreviewUrlResource(Resource): content_type.startswith("application/xhtml") ): return True + + +def summarize_paragraphs(text_nodes, min_size=200, max_size=500): + # Try to get a summary of between 200 and 500 words, respecting + # first paragraph and then word boundaries. + # TODO: Respect sentences? + + description = '' + + # Keep adding paragraphs until we get to the MIN_SIZE. + for text_node in text_nodes: + if len(description) < min_size: + text_node = re.sub(r'[\t \r\n]+', ' ', text_node) + description += text_node + '\n\n' + else: + break + + description = description.strip() + description = re.sub(r'[\t ]+', ' ', description) + description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description) + + # If the concatenation of paragraphs to get above MIN_SIZE + # took us over MAX_SIZE, then we need to truncate mid paragraph + if len(description) > max_size: + new_desc = "" + + # This splits the paragraph into words, but keeping the + # (preceeding) whitespace intact so we can easily concat + # words back together. + for match in re.finditer("\s*\S+", description): + word = match.group() + + # Keep adding words while the total length is less than + # MAX_SIZE. + if len(word) + len(new_desc) < max_size: + new_desc += word + else: + # At this point the next word *will* take us over + # MAX_SIZE, but we also want to ensure that its not + # a huge word. If it is add it anyway and we'll + # truncate later. + if len(new_desc) < min_size: + new_desc += word + break + + # Double check that we're not over the limit + if len(new_desc) > max_size: + new_desc = new_desc[:max_size] + + # We always add an ellipsis because at the very least + # we chopped mid paragraph. + description = new_desc.strip() + "…" + return description if description else None |