diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 5022808ea9..303419d281 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,7 +17,7 @@
from twisted.internet import defer
from synapse.api.constants import Membership
-from synapse.api.errors import AuthError, SynapseError
+from synapse.api.errors import AuthError, SynapseError, Codes, NotFoundError
from synapse.types import UserID, create_requester
from synapse.http.servlet import parse_json_object_from_request
@@ -113,12 +114,18 @@ class PurgeMediaCacheRestServlet(ClientV1RestServlet):
class PurgeHistoryRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
- "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
+ "/admin/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
)
def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer)
+ """
super(PurgeHistoryRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
+ self.store = hs.get_datastore()
@defer.inlineCallbacks
def on_POST(self, request, room_id, event_id):
@@ -128,9 +135,93 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
if not is_admin:
raise AuthError(403, "You are not a server admin")
- yield self.handlers.message_handler.purge_history(room_id, event_id)
+ body = parse_json_object_from_request(request, allow_empty_body=True)
- defer.returnValue((200, {}))
+ delete_local_events = bool(body.get("delete_local_events", False))
+
+ # establish the topological ordering we should keep events from. The
+ # user can provide an event_id in the URL or the request body, or can
+ # provide a timestamp in the request body.
+ if event_id is None:
+ event_id = body.get('purge_up_to_event_id')
+
+ if event_id is not None:
+ event = yield self.store.get_event(event_id)
+
+ if event.room_id != room_id:
+ raise SynapseError(400, "Event is for wrong room.")
+
+ depth = event.depth
+ logger.info(
+ "[purge] purging up to depth %i (event_id %s)",
+ depth, event_id,
+ )
+ elif 'purge_up_to_ts' in body:
+ ts = body['purge_up_to_ts']
+ if not isinstance(ts, int):
+ raise SynapseError(
+ 400, "purge_up_to_ts must be an int",
+ errcode=Codes.BAD_JSON,
+ )
+
+ stream_ordering = (
+ yield self.store.find_first_stream_ordering_after_ts(ts)
+ )
+
+ (_, depth, _) = (
+ yield self.store.get_room_event_after_stream_ordering(
+ room_id, stream_ordering,
+ )
+ )
+ logger.info(
+ "[purge] purging up to depth %i (received_ts %i => "
+ "stream_ordering %i)",
+ depth, ts, stream_ordering,
+ )
+ else:
+ raise SynapseError(
+ 400,
+ "must specify purge_up_to_event_id or purge_up_to_ts",
+ errcode=Codes.BAD_JSON,
+ )
+
+ purge_id = yield self.handlers.message_handler.start_purge_history(
+ room_id, depth,
+ delete_local_events=delete_local_events,
+ )
+
+ defer.returnValue((200, {
+ "purge_id": purge_id,
+ }))
+
+
+class PurgeHistoryStatusRestServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns(
+ "/admin/purge_history_status/(?P<purge_id>[^/]+)"
+ )
+
+ def __init__(self, hs):
+ """
+
+ Args:
+ hs (synapse.server.HomeServer)
+ """
+ super(PurgeHistoryStatusRestServlet, self).__init__(hs)
+ self.handlers = hs.get_handlers()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, purge_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ purge_status = self.handlers.message_handler.get_purge_status(purge_id)
+ if purge_status is None:
+ raise NotFoundError("purge id '%s' not found" % purge_id)
+
+ defer.returnValue((200, purge_status.asdict()))
class DeactivateAccountRestServlet(ClientV1RestServlet):
@@ -171,6 +262,8 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
self.store = hs.get_datastore()
self.handlers = hs.get_handlers()
self.state = hs.get_state_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
@defer.inlineCallbacks
def on_POST(self, request, room_id):
@@ -203,8 +296,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
)
new_room_id = info["room_id"]
- msg_handler = self.handlers.message_handler
- yield msg_handler.create_and_send_nonmember_event(
+ yield self.event_creation_handler.create_and_send_nonmember_event(
room_creator_requester,
{
"type": "m.room.message",
@@ -230,7 +322,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
logger.info("Kicking %r from %r...", user_id, room_id)
target_requester = create_requester(user_id)
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
room_id=room_id,
@@ -239,9 +331,9 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
ratelimit=False
)
- yield self.handlers.room_member_handler.forget(target_requester.user, room_id)
+ yield self.room_member_handler.forget(target_requester.user, room_id)
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
room_id=new_room_id,
@@ -289,6 +381,27 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
defer.returnValue((200, {"num_quarantined": num_quarantined}))
+class ListMediaInRoom(ClientV1RestServlet):
+ """Lists all of the media in a given room.
+ """
+ PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
+
+ def __init__(self, hs):
+ super(ListMediaInRoom, self).__init__(hs)
+ self.store = hs.get_datastore()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id):
+ requester = yield self.auth.get_user_by_req(request)
+ is_admin = yield self.auth.is_server_admin(requester.user)
+ if not is_admin:
+ raise AuthError(403, "You are not a server admin")
+
+ local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+
+ defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
+
+
class ResetPasswordRestServlet(ClientV1RestServlet):
"""Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse.
@@ -479,6 +592,7 @@ class SearchUsersRestServlet(ClientV1RestServlet):
def register_servlets(hs, http_server):
WhoisRestServlet(hs).register(http_server)
PurgeMediaCacheRestServlet(hs).register(http_server)
+ PurgeHistoryStatusRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
PurgeHistoryRestServlet(hs).register(http_server)
UsersRestServlet(hs).register(http_server)
@@ -487,3 +601,4 @@ def register_servlets(hs, http_server):
SearchUsersRestServlet(hs).register(http_server)
ShutdownRoomRestServlet(hs).register(http_server)
QuarantineMediaInRoom(hs).register(http_server)
+ ListMediaInRoom(hs).register(http_server)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 5669ecb724..45844aa2d2 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -191,19 +191,25 @@ class LoginRestServlet(ClientV1RestServlet):
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":
- if 'medium' not in identifier or 'address' not in identifier:
+ address = identifier.get('address')
+ medium = identifier.get('medium')
+
+ if medium is None or address is None:
raise SynapseError(400, "Invalid thirdparty identifier")
- address = identifier['address']
- if identifier['medium'] == 'email':
+ if medium == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
- identifier['medium'], address
+ medium, address,
)
if not user_id:
+ logger.warn(
+ "unknown 3pid identifier medium %s, address %r",
+ medium, address,
+ )
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
identifier = {
diff --git a/synapse/rest/client/v1/register.py b/synapse/rest/client/v1/register.py
index 32ed1d3ab2..5c5fa8f7ab 100644
--- a/synapse/rest/client/v1/register.py
+++ b/synapse/rest/client/v1/register.py
@@ -70,10 +70,15 @@ class RegisterRestServlet(ClientV1RestServlet):
self.handlers = hs.get_handlers()
def on_GET(self, request):
+
+ require_email = 'email' in self.hs.config.registrations_require_3pid
+ require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+ flows = []
if self.hs.config.enable_registration_captcha:
- return (
- 200,
- {"flows": [
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [
@@ -82,27 +87,34 @@ class RegisterRestServlet(ClientV1RestServlet):
LoginType.PASSWORD
]
},
+ ])
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([
{
"type": LoginType.RECAPTCHA,
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
}
- ]}
- )
+ ])
else:
- return (
- 200,
- {"flows": [
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if require_email or not require_msisdn:
+ flows.extend([
{
"type": LoginType.EMAIL_IDENTITY,
"stages": [
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
]
- },
+ }
+ ])
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([
{
"type": LoginType.PASSWORD
}
- ]}
- )
+ ])
+ return (200, {"flows": flows})
@defer.inlineCallbacks
def on_POST(self, request):
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 80989731fa..70d788deea 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -82,6 +83,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
+ self.event_creation_hander = hs.get_event_creation_handler()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
# /room/$roomid/state/$eventtype
@@ -154,7 +157,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
if event_type == EventTypes.Member:
membership = content.get("membership", None)
- event = yield self.handlers.room_member_handler.update_membership(
+ event = yield self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
@@ -162,15 +165,16 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content,
)
else:
- msg_handler = self.handlers.message_handler
- event, context = yield msg_handler.create_event(
+ event, context = yield self.event_creation_hander.create_event(
requester,
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id,
)
- yield msg_handler.send_nonmember_event(requester, event, context)
+ yield self.event_creation_hander.send_nonmember_event(
+ requester, event, context,
+ )
ret = {}
if event:
@@ -183,7 +187,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.event_creation_hander = hs.get_event_creation_handler()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
@@ -195,15 +199,19 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
- msg_handler = self.handlers.message_handler
- event = yield msg_handler.create_and_send_nonmember_event(
+ event_dict = {
+ "type": event_type,
+ "content": content,
+ "room_id": room_id,
+ "sender": requester.user.to_string(),
+ }
+
+ if 'ts' in request.args and requester.app_service:
+ event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
+
+ event = yield self.event_creation_hander.create_and_send_nonmember_event(
requester,
- {
- "type": event_type,
- "content": content,
- "room_id": room_id,
- "sender": requester.user.to_string(),
- },
+ event_dict,
txn_id=txn_id,
)
@@ -222,7 +230,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
class JoinRoomAliasServlet(ClientV1RestServlet):
def __init__(self, hs):
super(JoinRoomAliasServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
# /join/$room_identifier[/$txn_id]
@@ -250,7 +258,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
- handler = self.handlers.room_member_handler
+ handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
@@ -259,7 +267,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
room_identifier,
))
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
@@ -487,13 +495,35 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
defer.returnValue((200, content))
-class RoomEventContext(ClientV1RestServlet):
+class RoomEventServlet(ClientV1RestServlet):
+ PATTERNS = client_path_patterns(
+ "/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(RoomEventServlet, self).__init__(hs)
+ self.clock = hs.get_clock()
+ self.event_handler = hs.get_event_handler()
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, room_id, event_id):
+ requester = yield self.auth.get_user_by_req(request)
+ event = yield self.event_handler.get_event(requester.user, event_id)
+
+ time_now = self.clock.time_msec()
+ if event:
+ defer.returnValue((200, serialize_event(event, time_now)))
+ else:
+ defer.returnValue((404, "Event not found."))
+
+
+class RoomEventContextServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns(
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
)
def __init__(self, hs):
- super(RoomEventContext, self).__init__(hs)
+ super(RoomEventContextServlet, self).__init__(hs)
self.clock = hs.get_clock()
self.handlers = hs.get_handlers()
@@ -533,7 +563,7 @@ class RoomEventContext(ClientV1RestServlet):
class RoomForgetRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomForgetRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
@@ -546,7 +576,7 @@ class RoomForgetRestServlet(ClientV1RestServlet):
allow_guest=False,
)
- yield self.handlers.room_member_handler.forget(
+ yield self.room_member_handler.forget(
user=requester.user,
room_id=room_id,
)
@@ -564,12 +594,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomMembershipRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.room_member_handler = hs.get_room_member_handler()
def register(self, http_server):
# /rooms/$roomid/[invite|join|leave]
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/"
- "(?P<membership_action>join|invite|leave|ban|unban|kick|forget)")
+ "(?P<membership_action>join|invite|leave|ban|unban|kick)")
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks
@@ -593,7 +623,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
content = {}
if membership_action == "invite" and self._has_3pid_invite_keys(content):
- yield self.handlers.room_member_handler.do_3pid_invite(
+ yield self.room_member_handler.do_3pid_invite(
room_id,
requester.user,
content["medium"],
@@ -615,7 +645,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
if 'reason' in content and membership_action in ['kick', 'ban']:
event_content = {'reason': content['reason']}
- yield self.handlers.room_member_handler.update_membership(
+ yield self.room_member_handler.update_membership(
requester=requester,
target=target,
room_id=room_id,
@@ -643,6 +673,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomRedactEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
+ self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
@@ -653,8 +684,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
requester = yield self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- msg_handler = self.handlers.message_handler
- event = yield msg_handler.create_and_send_nonmember_event(
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
@@ -803,4 +833,5 @@ def register_servlets(hs, http_server):
RoomTypingRestServlet(hs).register(http_server)
SearchRestServlet(hs).register(http_server)
JoinedRoomsRestServlet(hs).register(http_server)
- RoomEventContext(hs).register(http_server)
+ RoomEventServlet(hs).register(http_server)
+ RoomEventContextServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index 385a3ad2ec..30523995af 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -26,6 +26,7 @@ from synapse.http.servlet import (
)
from synapse.util.async import run_on_reactor
from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler
logger = logging.getLogger(__name__)
@@ -47,6 +48,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
'id_server', 'client_secret', 'email', 'send_attempt'
])
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
@@ -78,6 +84,11 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -217,6 +228,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.datastore.get_user_id_by_threepid(
'email', body['email']
)
@@ -255,6 +271,11 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.datastore.get_user_id_by_threepid(
'msisdn', msisdn
)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index e9d88a8895..0ba62bddc1 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -26,6 +26,7 @@ from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
)
from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import check_3pid_allowed
from ._base import client_v2_patterns, interactive_auth_handler
@@ -70,6 +71,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
'id_server', 'client_secret', 'email', 'send_attempt'
])
+ if not check_3pid_allowed(self.hs, "email", body['email']):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'email', body['email']
)
@@ -105,6 +111,11 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
+ if not check_3pid_allowed(self.hs, "msisdn", msisdn):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
+ )
+
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
'msisdn', msisdn
)
@@ -172,7 +183,7 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_auth_handler()
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
- self.room_member_handler = hs.get_handlers().room_member_handler
+ self.room_member_handler = hs.get_room_member_handler()
self.device_handler = hs.get_device_handler()
self.macaroon_gen = hs.get_macaroon_generator()
@@ -305,31 +316,67 @@ class RegisterRestServlet(RestServlet):
if 'x_show_msisdn' in body and body['x_show_msisdn']:
show_msisdn = True
+ # FIXME: need a better error than "no auth flow found" for scenarios
+ # where we required 3PID for registration but the user didn't give one
+ require_email = 'email' in self.hs.config.registrations_require_3pid
+ require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
+
+ flows = []
if self.hs.config.enable_registration_captcha:
- flows = [
- [LoginType.RECAPTCHA],
- [LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
- ]
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([[LoginType.RECAPTCHA]])
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
+
if show_msisdn:
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if not require_email:
+ flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
+ # always let users provide both MSISDN & email
flows.extend([
- [LoginType.MSISDN, LoginType.RECAPTCHA],
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
])
else:
- flows = [
- [LoginType.DUMMY],
- [LoginType.EMAIL_IDENTITY],
- ]
+ # only support 3PIDless registration if no 3PIDs are required
+ if not require_email and not require_msisdn:
+ flows.extend([[LoginType.DUMMY]])
+ # only support the email-only flow if we don't require MSISDN 3PIDs
+ if not require_msisdn:
+ flows.extend([[LoginType.EMAIL_IDENTITY]])
+
if show_msisdn:
+ # only support the MSISDN-only flow if we don't require email 3PIDs
+ if not require_email or require_msisdn:
+ flows.extend([[LoginType.MSISDN]])
+ # always let users provide both MSISDN & email
flows.extend([
- [LoginType.MSISDN],
- [LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
+ [LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
])
auth_result, params, session_id = yield self.auth_handler.check_auth(
flows, body, self.hs.get_ip_from_request(request)
)
+ # Check that we're not trying to register a denied 3pid.
+ #
+ # the user-facing checks will probably already have happened in
+ # /register/email/requestToken when we requested a 3pid, but that's not
+ # guaranteed.
+
+ if auth_result:
+ for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
+ if login_type in auth_result:
+ medium = auth_result[login_type]['medium']
+ address = auth_result[login_type]['address']
+
+ if not check_3pid_allowed(self.hs, medium, address):
+ raise SynapseError(
+ 403, "Third party identifier is not allowed",
+ Codes.THREEPID_DENIED,
+ )
+
if registered_user_id is not None:
logger.info(
"Already registered user ID %r for this session",
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index cc2842aa72..17e6079cba 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -93,6 +93,7 @@ class RemoteKey(Resource):
self.store = hs.get_datastore()
self.version_string = hs.version_string
self.clock = hs.get_clock()
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
def render_GET(self, request):
self.async_render_GET(request)
@@ -137,6 +138,13 @@ class RemoteKey(Resource):
logger.info("Handling query for keys %r", query)
store_queries = []
for server_name, key_ids in query.items():
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ logger.debug("Federation denied with %s", server_name)
+ continue
+
if not key_ids:
key_ids = (None,)
for key_id in key_ids:
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 95fa95fce3..e7ac01da01 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -70,38 +70,11 @@ def respond_with_file(request, media_type, file_path,
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
- request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
- if upload_name:
- if is_ascii(upload_name):
- request.setHeader(
- b"Content-Disposition",
- b"inline; filename=%s" % (
- urllib.quote(upload_name.encode("utf-8")),
- ),
- )
- else:
- request.setHeader(
- b"Content-Disposition",
- b"inline; filename*=utf-8''%s" % (
- urllib.quote(upload_name.encode("utf-8")),
- ),
- )
-
- # cache for at least a day.
- # XXX: we might want to turn this off for data we don't want to
- # recommend caching as it's sensitive or private - or at least
- # select private. don't bother setting Expires as all our
- # clients are smart enough to be happy with Cache-Control
- request.setHeader(
- b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
- )
if file_size is None:
stat = os.stat(file_path)
file_size = stat.st_size
- request.setHeader(
- b"Content-Length", b"%d" % (file_size,)
- )
+ add_file_headers(request, media_type, file_size, upload_name)
with open(file_path, "rb") as f:
yield logcontext.make_deferred_yieldable(
@@ -111,3 +84,118 @@ def respond_with_file(request, media_type, file_path,
finish_request(request)
else:
respond_404(request)
+
+
+def add_file_headers(request, media_type, file_size, upload_name):
+ """Adds the correct response headers in preparation for responding with the
+ media.
+
+ Args:
+ request (twisted.web.http.Request)
+ media_type (str): The media/content type.
+ file_size (int): Size in bytes of the media, if known.
+ upload_name (str): The name of the requested file, if any.
+ """
+ request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+ if upload_name:
+ if is_ascii(upload_name):
+ request.setHeader(
+ b"Content-Disposition",
+ b"inline; filename=%s" % (
+ urllib.quote(upload_name.encode("utf-8")),
+ ),
+ )
+ else:
+ request.setHeader(
+ b"Content-Disposition",
+ b"inline; filename*=utf-8''%s" % (
+ urllib.quote(upload_name.encode("utf-8")),
+ ),
+ )
+
+ # cache for at least a day.
+ # XXX: we might want to turn this off for data we don't want to
+ # recommend caching as it's sensitive or private - or at least
+ # select private. don't bother setting Expires as all our
+ # clients are smart enough to be happy with Cache-Control
+ request.setHeader(
+ b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
+ )
+
+ request.setHeader(
+ b"Content-Length", b"%d" % (file_size,)
+ )
+
+
+@defer.inlineCallbacks
+def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
+ """Responds to the request with given responder. If responder is None then
+ returns 404.
+
+ Args:
+ request (twisted.web.http.Request)
+ responder (Responder|None)
+ media_type (str): The media/content type.
+ file_size (int|None): Size in bytes of the media. If not known it should be None
+ upload_name (str|None): The name of the requested file, if any.
+ """
+ if not responder:
+ respond_404(request)
+ return
+
+ add_file_headers(request, media_type, file_size, upload_name)
+ with responder:
+ yield responder.write_to_consumer(request)
+ finish_request(request)
+
+
+class Responder(object):
+ """Represents a response that can be streamed to the requester.
+
+ Responder is a context manager which *must* be used, so that any resources
+ held can be cleaned up.
+ """
+ def write_to_consumer(self, consumer):
+ """Stream response into consumer
+
+ Args:
+ consumer (IConsumer)
+
+ Returns:
+ Deferred: Resolves once the response has finished being written
+ """
+ pass
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ pass
+
+
+class FileInfo(object):
+ """Details about a requested/uploaded file.
+
+ Attributes:
+ server_name (str): The server name where the media originated from,
+ or None if local.
+ file_id (str): The local ID of the file. For local files this is the
+ same as the media_id
+ url_cache (bool): If the file is for the url preview cache
+ thumbnail (bool): Whether the file is a thumbnail or not.
+ thumbnail_width (int)
+ thumbnail_height (int)
+ thumbnail_method (str)
+ thumbnail_type (str): Content type of thumbnail, e.g. image/png
+ """
+ def __init__(self, server_name, file_id, url_cache=False,
+ thumbnail=False, thumbnail_width=None, thumbnail_height=None,
+ thumbnail_method=None, thumbnail_type=None):
+ self.server_name = server_name
+ self.file_id = file_id
+ self.url_cache = url_cache
+ self.thumbnail = thumbnail
+ self.thumbnail_width = thumbnail_width
+ self.thumbnail_height = thumbnail_height
+ self.thumbnail_method = thumbnail_method
+ self.thumbnail_type = thumbnail_type
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 6879249c8a..fe7e17596f 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -14,7 +14,7 @@
# limitations under the License.
import synapse.http.servlet
-from ._base import parse_media_id, respond_with_file, respond_404
+from ._base import parse_media_id, respond_404
from twisted.web.resource import Resource
from synapse.http.server import request_handler, set_cors_headers
@@ -32,12 +32,12 @@ class DownloadResource(Resource):
def __init__(self, hs, media_repo):
Resource.__init__(self)
- self.filepaths = media_repo.filepaths
self.media_repo = media_repo
self.server_name = hs.hostname
- self.store = hs.get_datastore()
- self.version_string = hs.version_string
+
+ # Both of these are expected by @request_handler()
self.clock = hs.get_clock()
+ self.version_string = hs.version_string
def render_GET(self, request):
self._async_render_GET(request)
@@ -57,59 +57,16 @@ class DownloadResource(Resource):
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
- yield self._respond_local_file(request, media_id, name)
+ yield self.media_repo.get_local_media(request, media_id, name)
else:
- yield self._respond_remote_file(
- request, server_name, media_id, name
- )
-
- @defer.inlineCallbacks
- def _respond_local_file(self, request, media_id, name):
- media_info = yield self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
- respond_404(request)
- return
-
- media_type = media_info["media_type"]
- media_length = media_info["media_length"]
- upload_name = name if name else media_info["upload_name"]
- if media_info["url_cache"]:
- # TODO: Check the file still exists, if it doesn't we can redownload
- # it from the url `media_info["url_cache"]`
- file_path = self.filepaths.url_cache_filepath(media_id)
- else:
- file_path = self.filepaths.local_media_filepath(media_id)
-
- yield respond_with_file(
- request, media_type, file_path, media_length,
- upload_name=upload_name,
- )
-
- @defer.inlineCallbacks
- def _respond_remote_file(self, request, server_name, media_id, name):
- # don't forward requests for remote media if allow_remote is false
- allow_remote = synapse.http.servlet.parse_boolean(
- request, "allow_remote", default=True)
- if not allow_remote:
- logger.info(
- "Rejecting request for remote media %s/%s due to allow_remote",
- server_name, media_id,
- )
- respond_404(request)
- return
-
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
- media_type = media_info["media_type"]
- media_length = media_info["media_length"]
- filesystem_id = media_info["filesystem_id"]
- upload_name = name if name else media_info["upload_name"]
-
- file_path = self.filepaths.remote_media_filepath(
- server_name, filesystem_id
- )
-
- yield respond_with_file(
- request, media_type, file_path, media_length,
- upload_name=upload_name,
- )
+ allow_remote = synapse.http.servlet.parse_boolean(
+ request, "allow_remote", default=True)
+ if not allow_remote:
+ logger.info(
+ "Rejecting request for remote media %s/%s due to allow_remote",
+ server_name, media_id,
+ )
+ respond_404(request)
+ return
+
+ yield self.media_repo.get_remote_media(request, server_name, media_id, name)
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index eed9056a2f..bb79599379 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,6 +19,7 @@ import twisted.internet.error
import twisted.web.http
from twisted.web.resource import Resource
+from ._base import respond_404, FileInfo, respond_with_responder
from .upload_resource import UploadResource
from .download_resource import DownloadResource
from .thumbnail_resource import ThumbnailResource
@@ -25,15 +27,18 @@ from .identicon_resource import IdenticonResource
from .preview_url_resource import PreviewUrlResource
from .filepath import MediaFilePaths
from .thumbnailer import Thumbnailer
+from .storage_provider import StorageProviderWrapper
+from .media_storage import MediaStorage
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string
-from synapse.api.errors import SynapseError, HttpResponseException, \
- NotFoundError
+from synapse.api.errors import (
+ SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
+)
from synapse.util.async import Linearizer
from synapse.util.stringutils import is_ascii
-from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
+from synapse.util.logcontext import make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
import os
@@ -47,7 +52,7 @@ import urlparse
logger = logging.getLogger(__name__)
-UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
+UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
class MediaRepository(object):
@@ -63,96 +68,62 @@ class MediaRepository(object):
self.primary_base_path = hs.config.media_store_path
self.filepaths = MediaFilePaths(self.primary_base_path)
- self.backup_base_path = hs.config.backup_media_store_path
-
- self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
-
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set()
+ self.recently_accessed_locals = set()
- self.clock.looping_call(
- self._update_recently_accessed_remotes,
- UPDATE_RECENTLY_ACCESSED_REMOTES_TS
- )
+ self.federation_domain_whitelist = hs.config.federation_domain_whitelist
- @defer.inlineCallbacks
- def _update_recently_accessed_remotes(self):
- media = self.recently_accessed_remotes
- self.recently_accessed_remotes = set()
-
- yield self.store.update_cached_last_access_time(
- media, self.clock.time_msec()
- )
+ # List of StorageProviders where we should search for media and
+ # potentially upload to.
+ storage_providers = []
- @staticmethod
- def _makedirs(filepath):
- dirname = os.path.dirname(filepath)
- if not os.path.exists(dirname):
- os.makedirs(dirname)
+ for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
+ backend = clz(hs, provider_config)
+ provider = StorageProviderWrapper(
+ backend,
+ store_local=wrapper_config.store_local,
+ store_remote=wrapper_config.store_remote,
+ store_synchronous=wrapper_config.store_synchronous,
+ )
+ storage_providers.append(provider)
- @staticmethod
- def _write_file_synchronously(source, fname):
- """Write `source` to the path `fname` synchronously. Should be called
- from a thread.
+ self.media_storage = MediaStorage(
+ self.primary_base_path, self.filepaths, storage_providers,
+ )
- Args:
- source: A file like object to be written
- fname (str): Path to write to
- """
- MediaRepository._makedirs(fname)
- source.seek(0) # Ensure we read from the start of the file
- with open(fname, "wb") as f:
- shutil.copyfileobj(source, f)
+ self.clock.looping_call(
+ self._update_recently_accessed,
+ UPDATE_RECENTLY_ACCESSED_TS,
+ )
@defer.inlineCallbacks
- def write_to_file_and_backup(self, source, path):
- """Write `source` to the on disk media store, and also the backup store
- if configured.
-
- Args:
- source: A file like object that should be written
- path (str): Relative path to write file to
-
- Returns:
- Deferred[str]: the file path written to in the primary media store
- """
- fname = os.path.join(self.primary_base_path, path)
-
- # Write to the main repository
- yield make_deferred_yieldable(threads.deferToThread(
- self._write_file_synchronously, source, fname,
- ))
+ def _update_recently_accessed(self):
+ remote_media = self.recently_accessed_remotes
+ self.recently_accessed_remotes = set()
- # Write to backup repository
- yield self.copy_to_backup(path)
+ local_media = self.recently_accessed_locals
+ self.recently_accessed_locals = set()
- defer.returnValue(fname)
+ yield self.store.update_cached_last_access_time(
+ local_media, remote_media, self.clock.time_msec()
+ )
- @defer.inlineCallbacks
- def copy_to_backup(self, path):
- """Copy a file from the primary to backup media store, if configured.
+ def mark_recently_accessed(self, server_name, media_id):
+ """Mark the given media as recently accessed.
Args:
- path(str): Relative path to write file to
+ server_name (str|None): Origin server of media, or None if local
+ media_id (str): The media ID of the content
"""
- if self.backup_base_path:
- primary_fname = os.path.join(self.primary_base_path, path)
- backup_fname = os.path.join(self.backup_base_path, path)
-
- # We can either wait for successful writing to the backup repository
- # or write in the background and immediately return
- if self.synchronous_backup_media_store:
- yield make_deferred_yieldable(threads.deferToThread(
- shutil.copyfile, primary_fname, backup_fname,
- ))
- else:
- preserve_fn(threads.deferToThread)(
- shutil.copyfile, primary_fname, backup_fname,
- )
+ if server_name:
+ self.recently_accessed_remotes.add((server_name, media_id))
+ else:
+ self.recently_accessed_locals.add(media_id)
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
@@ -171,10 +142,13 @@ class MediaRepository(object):
"""
media_id = random_string(24)
- fname = yield self.write_to_file_and_backup(
- content, self.filepaths.local_media_filepath_rel(media_id)
+ file_info = FileInfo(
+ server_name=None,
+ file_id=media_id,
)
+ fname = yield self.media_storage.store_file(content, file_info)
+
logger.info("Stored local media in file %r", fname)
yield self.store.store_local_media(
@@ -185,134 +159,275 @@ class MediaRepository(object):
media_length=content_length,
user_id=auth_user,
)
- media_info = {
- "media_type": media_type,
- "media_length": content_length,
- }
- yield self._generate_thumbnails(None, media_id, media_info)
+ yield self._generate_thumbnails(
+ None, media_id, media_id, media_type,
+ )
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@defer.inlineCallbacks
- def get_remote_media(self, server_name, media_id):
+ def get_local_media(self, request, media_id, name):
+ """Responds to reqests for local media, if exists, or returns 404.
+
+ Args:
+ request(twisted.web.http.Request)
+ media_id (str): The media ID of the content. (This is the same as
+ the file_id for local content.)
+ name (str|None): Optional name that, if specified, will be used as
+ the filename in the Content-Disposition header of the response.
+
+ Returns:
+ Deferred: Resolves once a response has successfully been written
+ to request
+ """
+ media_info = yield self.store.get_local_media(media_id)
+ if not media_info or media_info["quarantined_by"]:
+ respond_404(request)
+ return
+
+ self.mark_recently_accessed(None, media_id)
+
+ media_type = media_info["media_type"]
+ media_length = media_info["media_length"]
+ upload_name = name if name else media_info["upload_name"]
+ url_cache = media_info["url_cache"]
+
+ file_info = FileInfo(
+ None, media_id,
+ url_cache=url_cache,
+ )
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(
+ request, responder, media_type, media_length, upload_name,
+ )
+
+ @defer.inlineCallbacks
+ def get_remote_media(self, request, server_name, media_id, name):
+ """Respond to requests for remote media.
+
+ Args:
+ request(twisted.web.http.Request)
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+ name (str|None): Optional name that, if specified, will be used as
+ the filename in the Content-Disposition header of the response.
+
+ Returns:
+ Deferred: Resolves once a response has successfully been written
+ to request
+ """
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(server_name)
+
+ self.mark_recently_accessed(server_name, media_id)
+
+ # We linearize here to ensure that we don't try and download remote
+ # media multiple times concurrently
+ key = (server_name, media_id)
+ with (yield self.remote_media_linearizer.queue(key)):
+ responder, media_info = yield self._get_remote_media_impl(
+ server_name, media_id,
+ )
+
+ # We deliberately stream the file outside the lock
+ if responder:
+ media_type = media_info["media_type"]
+ media_length = media_info["media_length"]
+ upload_name = name if name else media_info["upload_name"]
+ yield respond_with_responder(
+ request, responder, media_type, media_length, upload_name,
+ )
+ else:
+ respond_404(request)
+
+ @defer.inlineCallbacks
+ def get_remote_media_info(self, server_name, media_id):
+ """Gets the media info associated with the remote file, downloading
+ if necessary.
+
+ Args:
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+
+ Returns:
+ Deferred[dict]: The media_info of the file
+ """
+ if (
+ self.federation_domain_whitelist is not None and
+ server_name not in self.federation_domain_whitelist
+ ):
+ raise FederationDeniedError(server_name)
+
+ # We linearize here to ensure that we don't try and download remote
+ # media multiple times concurrently
key = (server_name, media_id)
with (yield self.remote_media_linearizer.queue(key)):
- media_info = yield self._get_remote_media_impl(server_name, media_id)
+ responder, media_info = yield self._get_remote_media_impl(
+ server_name, media_id,
+ )
+
+ # Ensure we actually use the responder so that it releases resources
+ if responder:
+ with responder:
+ pass
+
defer.returnValue(media_info)
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
+ """Looks for media in local cache, if not there then attempt to
+ download from remote server.
+
+ Args:
+ server_name (str): Remote server_name where the media originated.
+ media_id (str): The media ID of the content (as defined by the
+ remote server).
+
+ Returns:
+ Deferred[(Responder, media_info)]
+ """
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
- if not media_info:
- media_info = yield self._download_remote_file(
- server_name, media_id
- )
- elif media_info["quarantined_by"]:
- raise NotFoundError()
+
+ # file_id is the ID we use to track the file locally. If we've already
+ # seen the file then reuse the existing ID, otherwise genereate a new
+ # one.
+ if media_info:
+ file_id = media_info["filesystem_id"]
else:
- self.recently_accessed_remotes.add((server_name, media_id))
- yield self.store.update_cached_last_access_time(
- [(server_name, media_id)], self.clock.time_msec()
- )
- defer.returnValue(media_info)
+ file_id = random_string(24)
+
+ file_info = FileInfo(server_name, file_id)
+
+ # If we have an entry in the DB, try and look for it
+ if media_info:
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
+ raise NotFoundError()
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ defer.returnValue((responder, media_info))
+
+ # Failed to find the file anywhere, lets download it.
+
+ media_info = yield self._download_remote_file(
+ server_name, media_id, file_id
+ )
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ defer.returnValue((responder, media_info))
@defer.inlineCallbacks
- def _download_remote_file(self, server_name, media_id):
- file_id = random_string(24)
+ def _download_remote_file(self, server_name, media_id, file_id):
+ """Attempt to download the remote file from the given server name,
+ using the given file_id as the local id.
+
+ Args:
+ server_name (str): Originating server
+ media_id (str): The media ID of the content (as defined by the
+ remote server). This is different than the file_id, which is
+ locally generated.
+ file_id (str): Local file ID
+
+ Returns:
+ Deferred[MediaInfo]
+ """
- fpath = self.filepaths.remote_media_filepath_rel(
- server_name, file_id
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
)
- fname = os.path.join(self.primary_base_path, fpath)
- self._makedirs(fname)
- try:
- with open(fname, "wb") as f:
- request_path = "/".join((
- "/_matrix/media/v1/download", server_name, media_id,
- ))
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ request_path = "/".join((
+ "/_matrix/media/v1/download", server_name, media_id,
+ ))
+ try:
+ length, headers = yield self.client.get_file(
+ server_name, request_path, output_stream=f,
+ max_size=self.max_upload_size, args={
+ # tell the remote server to 404 if it doesn't
+ # recognise the server_name, to make sure we don't
+ # end up with a routing loop.
+ "allow_remote": "false",
+ }
+ )
+ except twisted.internet.error.DNSLookupError as e:
+ logger.warn("HTTP error fetching remote media %s/%s: %r",
+ server_name, media_id, e)
+ raise NotFoundError()
+
+ except HttpResponseException as e:
+ logger.warn("HTTP error fetching remote media %s/%s: %s",
+ server_name, media_id, e.response)
+ if e.code == twisted.web.http.NOT_FOUND:
+ raise SynapseError.from_http_response_exception(e)
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ except SynapseError:
+ logger.exception("Failed to fetch remote media %s/%s",
+ server_name, media_id)
+ raise
+ except NotRetryingDestination:
+ logger.warn("Not retrying destination %r", server_name)
+ raise SynapseError(502, "Failed to fetch remote media")
+ except Exception:
+ logger.exception("Failed to fetch remote media %s/%s",
+ server_name, media_id)
+ raise SynapseError(502, "Failed to fetch remote media")
+
+ yield finish()
+
+ media_type = headers["Content-Type"][0]
+
+ time_now_ms = self.clock.time_msec()
+
+ content_disposition = headers.get("Content-Disposition", None)
+ if content_disposition:
+ _, params = cgi.parse_header(content_disposition[0],)
+ upload_name = None
+
+ # First check if there is a valid UTF-8 filename
+ upload_name_utf8 = params.get("filename*", None)
+ if upload_name_utf8:
+ if upload_name_utf8.lower().startswith("utf-8''"):
+ upload_name = upload_name_utf8[7:]
+
+ # If there isn't check for an ascii name.
+ if not upload_name:
+ upload_name_ascii = params.get("filename", None)
+ if upload_name_ascii and is_ascii(upload_name_ascii):
+ upload_name = upload_name_ascii
+
+ if upload_name:
+ upload_name = urlparse.unquote(upload_name)
try:
- length, headers = yield self.client.get_file(
- server_name, request_path, output_stream=f,
- max_size=self.max_upload_size, args={
- # tell the remote server to 404 if it doesn't
- # recognise the server_name, to make sure we don't
- # end up with a routing loop.
- "allow_remote": "false",
- }
- )
- except twisted.internet.error.DNSLookupError as e:
- logger.warn("HTTP error fetching remote media %s/%s: %r",
- server_name, media_id, e)
- raise NotFoundError()
-
- except HttpResponseException as e:
- logger.warn("HTTP error fetching remote media %s/%s: %s",
- server_name, media_id, e.response)
- if e.code == twisted.web.http.NOT_FOUND:
- raise SynapseError.from_http_response_exception(e)
- raise SynapseError(502, "Failed to fetch remote media")
-
- except SynapseError:
- logger.exception("Failed to fetch remote media %s/%s",
- server_name, media_id)
- raise
- except NotRetryingDestination:
- logger.warn("Not retrying destination %r", server_name)
- raise SynapseError(502, "Failed to fetch remote media")
- except Exception:
- logger.exception("Failed to fetch remote media %s/%s",
- server_name, media_id)
- raise SynapseError(502, "Failed to fetch remote media")
-
- yield self.copy_to_backup(fpath)
-
- media_type = headers["Content-Type"][0]
- time_now_ms = self.clock.time_msec()
-
- content_disposition = headers.get("Content-Disposition", None)
- if content_disposition:
- _, params = cgi.parse_header(content_disposition[0],)
- upload_name = None
-
- # First check if there is a valid UTF-8 filename
- upload_name_utf8 = params.get("filename*", None)
- if upload_name_utf8:
- if upload_name_utf8.lower().startswith("utf-8''"):
- upload_name = upload_name_utf8[7:]
-
- # If there isn't check for an ascii name.
- if not upload_name:
- upload_name_ascii = params.get("filename", None)
- if upload_name_ascii and is_ascii(upload_name_ascii):
- upload_name = upload_name_ascii
-
- if upload_name:
- upload_name = urlparse.unquote(upload_name)
- try:
- upload_name = upload_name.decode("utf-8")
- except UnicodeDecodeError:
- upload_name = None
- else:
- upload_name = None
-
- logger.info("Stored remote media in file %r", fname)
-
- yield self.store.store_cached_remote_media(
- origin=server_name,
- media_id=media_id,
- media_type=media_type,
- time_now_ms=self.clock.time_msec(),
- upload_name=upload_name,
- media_length=length,
- filesystem_id=file_id,
- )
- except Exception:
- os.remove(fname)
- raise
+ upload_name = upload_name.decode("utf-8")
+ except UnicodeDecodeError:
+ upload_name = None
+ else:
+ upload_name = None
+
+ logger.info("Stored remote media in file %r", fname)
+
+ yield self.store.store_cached_remote_media(
+ origin=server_name,
+ media_id=media_id,
+ media_type=media_type,
+ time_now_ms=self.clock.time_msec(),
+ upload_name=upload_name,
+ media_length=length,
+ filesystem_id=file_id,
+ )
media_info = {
"media_type": media_type,
@@ -323,7 +438,7 @@ class MediaRepository(object):
}
yield self._generate_thumbnails(
- server_name, media_id, media_info
+ server_name, media_id, file_id, media_type,
)
defer.returnValue(media_info)
@@ -357,8 +472,10 @@ class MediaRepository(object):
@defer.inlineCallbacks
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
- t_method, t_type):
- input_path = self.filepaths.local_media_filepath(media_id)
+ t_method, t_type, url_cache):
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+ None, media_id, url_cache=url_cache,
+ ))
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -368,11 +485,19 @@ class MediaRepository(object):
if t_byte_source:
try:
- output_path = yield self.write_to_file_and_backup(
- t_byte_source,
- self.filepaths.local_media_thumbnail_rel(
- media_id, t_width, t_height, t_type, t_method
- )
+ file_info = FileInfo(
+ server_name=None,
+ file_id=media_id,
+ url_cache=url_cache,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ )
+
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
)
finally:
t_byte_source.close()
@@ -390,7 +515,9 @@ class MediaRepository(object):
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type):
- input_path = self.filepaths.remote_media_filepath(server_name, file_id)
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+ server_name, file_id, url_cache=False,
+ ))
thumbnailer = Thumbnailer(input_path)
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -400,11 +527,18 @@ class MediaRepository(object):
if t_byte_source:
try:
- output_path = yield self.write_to_file_and_backup(
- t_byte_source,
- self.filepaths.remote_media_thumbnail_rel(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=media_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ )
+
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
)
finally:
t_byte_source.close()
@@ -421,31 +555,29 @@ class MediaRepository(object):
defer.returnValue(output_path)
@defer.inlineCallbacks
- def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
+ def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
+ url_cache=False):
"""Generate and store thumbnails for an image.
Args:
- server_name(str|None): The server name if remote media, else None if local
- media_id(str)
- media_info(dict)
- url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
+ server_name (str|None): The server name if remote media, else None if local
+ media_id (str): The media ID of the content. (This is the same as
+ the file_id for local content)
+ file_id (str): Local file ID
+ media_type (str): The content type of the file
+ url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer
Returns:
Deferred[dict]: Dict with "width" and "height" keys of original image
"""
- media_type = media_info["media_type"]
- file_id = media_info.get("filesystem_id")
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
- if server_name:
- input_path = self.filepaths.remote_media_filepath(server_name, file_id)
- elif url_cache:
- input_path = self.filepaths.url_cache_filepath(media_id)
- else:
- input_path = self.filepaths.local_media_filepath(media_id)
+ input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
+ server_name, file_id, url_cache=url_cache,
+ ))
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
@@ -472,20 +604,6 @@ class MediaRepository(object):
# Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
- # Work out the correct file name for thumbnail
- if server_name:
- file_path = self.filepaths.remote_media_thumbnail_rel(
- server_name, file_id, t_width, t_height, t_type, t_method
- )
- elif url_cache:
- file_path = self.filepaths.url_cache_thumbnail_rel(
- media_id, t_width, t_height, t_type, t_method
- )
- else:
- file_path = self.filepaths.local_media_thumbnail_rel(
- media_id, t_width, t_height, t_type, t_method
- )
-
# Generate the thumbnail
if t_method == "crop":
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
@@ -505,9 +623,19 @@ class MediaRepository(object):
continue
try:
- # Write to disk
- output_path = yield self.write_to_file_and_backup(
- t_byte_source, file_path,
+ file_info = FileInfo(
+ server_name=server_name,
+ file_id=file_id,
+ thumbnail=True,
+ thumbnail_width=t_width,
+ thumbnail_height=t_height,
+ thumbnail_method=t_method,
+ thumbnail_type=t_type,
+ url_cache=url_cache,
+ )
+
+ output_path = yield self.media_storage.store_file(
+ t_byte_source, file_info,
)
finally:
t_byte_source.close()
@@ -620,7 +748,11 @@ class MediaRepositoryResource(Resource):
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))
- self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
+ self.putChild("thumbnail", ThumbnailResource(
+ hs, media_repo, media_repo.media_storage,
+ ))
self.putChild("identicon", IdenticonResource())
if hs.config.url_preview_enabled:
- self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
+ self.putChild("preview_url", PreviewUrlResource(
+ hs, media_repo, media_repo.media_storage,
+ ))
diff --git a/synapse/rest/media/v1/media_storage.py b/synapse/rest/media/v1/media_storage.py
new file mode 100644
index 0000000000..83471b3173
--- /dev/null
+++ b/synapse/rest/media/v1/media_storage.py
@@ -0,0 +1,259 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vecotr Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+from twisted.protocols.basic import FileSender
+
+from ._base import Responder
+
+from synapse.util.file_consumer import BackgroundFileConsumer
+from synapse.util.logcontext import make_deferred_yieldable
+
+import contextlib
+import os
+import logging
+import shutil
+import sys
+
+
+logger = logging.getLogger(__name__)
+
+
+class MediaStorage(object):
+ """Responsible for storing/fetching files from local sources.
+
+ Args:
+ local_media_directory (str): Base path where we store media on disk
+ filepaths (MediaFilePaths)
+ storage_providers ([StorageProvider]): List of StorageProvider that are
+ used to fetch and store files.
+ """
+
+ def __init__(self, local_media_directory, filepaths, storage_providers):
+ self.local_media_directory = local_media_directory
+ self.filepaths = filepaths
+ self.storage_providers = storage_providers
+
+ @defer.inlineCallbacks
+ def store_file(self, source, file_info):
+ """Write `source` to the on disk media store, and also any other
+ configured storage providers
+
+ Args:
+ source: A file like object that should be written
+ file_info (FileInfo): Info about the file to store
+
+ Returns:
+ Deferred[str]: the file path written to in the primary media store
+ """
+
+ with self.store_into_file(file_info) as (f, fname, finish_cb):
+ # Write to the main repository
+ yield make_deferred_yieldable(threads.deferToThread(
+ _write_file_synchronously, source, f,
+ ))
+ yield finish_cb()
+
+ defer.returnValue(fname)
+
+ @contextlib.contextmanager
+ def store_into_file(self, file_info):
+ """Context manager used to get a file like object to write into, as
+ described by file_info.
+
+ Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
+ like object that can be written to, fname is the absolute path of file
+ on disk, and finish_cb is a function that returns a Deferred.
+
+ fname can be used to read the contents from after upload, e.g. to
+ generate thumbnails.
+
+ finish_cb must be called and waited on after the file has been
+ successfully been written to. Should not be called if there was an
+ error.
+
+ Args:
+ file_info (FileInfo): Info about the file to store
+
+ Example:
+
+ with media_storage.store_into_file(info) as (f, fname, finish_cb):
+ # .. write into f ...
+ yield finish_cb()
+ """
+
+ path = self._file_info_to_path(file_info)
+ fname = os.path.join(self.local_media_directory, path)
+
+ dirname = os.path.dirname(fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ finished_called = [False]
+
+ @defer.inlineCallbacks
+ def finish():
+ for provider in self.storage_providers:
+ yield provider.store_file(path, file_info)
+
+ finished_called[0] = True
+
+ try:
+ with open(fname, "wb") as f:
+ yield f, fname, finish
+ except Exception:
+ t, v, tb = sys.exc_info()
+ try:
+ os.remove(fname)
+ except Exception:
+ pass
+ raise t, v, tb
+
+ if not finished_called:
+ raise Exception("Finished callback not called")
+
+ @defer.inlineCallbacks
+ def fetch_media(self, file_info):
+ """Attempts to fetch media described by file_info from the local cache
+ and configured storage providers.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ Deferred[Responder|None]: Returns a Responder if the file was found,
+ otherwise None.
+ """
+
+ path = self._file_info_to_path(file_info)
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ defer.returnValue(FileResponder(open(local_path, "rb")))
+
+ for provider in self.storage_providers:
+ res = yield provider.fetch(path, file_info)
+ if res:
+ defer.returnValue(res)
+
+ defer.returnValue(None)
+
+ @defer.inlineCallbacks
+ def ensure_media_is_in_local_cache(self, file_info):
+ """Ensures that the given file is in the local cache. Attempts to
+ download it from storage providers if it isn't.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ Deferred[str]: Full path to local file
+ """
+ path = self._file_info_to_path(file_info)
+ local_path = os.path.join(self.local_media_directory, path)
+ if os.path.exists(local_path):
+ defer.returnValue(local_path)
+
+ dirname = os.path.dirname(local_path)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ for provider in self.storage_providers:
+ res = yield provider.fetch(path, file_info)
+ if res:
+ with res:
+ consumer = BackgroundFileConsumer(open(local_path, "w"))
+ yield res.write_to_consumer(consumer)
+ yield consumer.wait()
+ defer.returnValue(local_path)
+
+ raise Exception("file could not be found")
+
+ def _file_info_to_path(self, file_info):
+ """Converts file_info into a relative path.
+
+ The path is suitable for storing files under a directory, e.g. used to
+ store files on local FS under the base media repository directory.
+
+ Args:
+ file_info (FileInfo)
+
+ Returns:
+ str
+ """
+ if file_info.url_cache:
+ if file_info.thumbnail:
+ return self.filepaths.url_cache_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method,
+ )
+ return self.filepaths.url_cache_filepath_rel(file_info.file_id)
+
+ if file_info.server_name:
+ if file_info.thumbnail:
+ return self.filepaths.remote_media_thumbnail_rel(
+ server_name=file_info.server_name,
+ file_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method
+ )
+ return self.filepaths.remote_media_filepath_rel(
+ file_info.server_name, file_info.file_id,
+ )
+
+ if file_info.thumbnail:
+ return self.filepaths.local_media_thumbnail_rel(
+ media_id=file_info.file_id,
+ width=file_info.thumbnail_width,
+ height=file_info.thumbnail_height,
+ content_type=file_info.thumbnail_type,
+ method=file_info.thumbnail_method
+ )
+ return self.filepaths.local_media_filepath_rel(
+ file_info.file_id,
+ )
+
+
+def _write_file_synchronously(source, dest):
+ """Write `source` to the file like `dest` synchronously. Should be called
+ from a thread.
+
+ Args:
+ source: A file like object that's to be written
+ dest: A file like object to be written to
+ """
+ source.seek(0) # Ensure we read from the start of the file
+ shutil.copyfileobj(source, dest)
+
+
+class FileResponder(Responder):
+ """Wraps an open file that can be sent to a request.
+
+ Args:
+ open_file (file): A file like object to be streamed ot the client,
+ is closed when finished streaming.
+ """
+ def __init__(self, open_file):
+ self.open_file = open_file
+
+ def write_to_consumer(self, consumer):
+ return FileSender().beginFileTransfer(self.open_file, consumer)
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.open_file.close()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index a413cb6226..0fc21540c6 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -12,11 +12,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import cgi
+import datetime
+import errno
+import fnmatch
+import itertools
+import logging
+import os
+import re
+import shutil
+import sys
+import traceback
+import simplejson as json
+import urlparse
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from twisted.web.resource import Resource
+from ._base import FileInfo
+
from synapse.api.errors import (
SynapseError, Codes,
)
@@ -31,25 +46,13 @@ from synapse.http.server import (
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
-import os
-import re
-import fnmatch
-import cgi
-import simplejson as json
-import urlparse
-import itertools
-import datetime
-import errno
-import shutil
-
-import logging
logger = logging.getLogger(__name__)
class PreviewUrlResource(Resource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self)
self.auth = hs.get_auth()
@@ -62,6 +65,7 @@ class PreviewUrlResource(Resource):
self.client = SpiderHttpClient(hs)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path
+ self.media_storage = media_storage
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
@@ -182,8 +186,10 @@ class PreviewUrlResource(Resource):
logger.debug("got media_info of '%s'" % media_info)
if _is_media(media_info['media_type']):
+ file_id = media_info['filesystem_id']
dims = yield self.media_repo._generate_thumbnails(
- None, media_info['filesystem_id'], media_info, url_cache=True,
+ None, file_id, file_id, media_info["media_type"],
+ url_cache=True,
)
og = {
@@ -228,8 +234,10 @@ class PreviewUrlResource(Resource):
if _is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
+ file_id = image_info['filesystem_id']
dims = yield self.media_repo._generate_thumbnails(
- None, image_info['filesystem_id'], image_info, url_cache=True,
+ None, file_id, file_id, image_info["media_type"],
+ url_cache=True,
)
if dims:
og["og:image:width"] = dims['width']
@@ -273,21 +281,34 @@ class PreviewUrlResource(Resource):
file_id = datetime.date.today().isoformat() + '_' + random_string(16)
- fpath = self.filepaths.url_cache_filepath_rel(file_id)
- fname = os.path.join(self.primary_base_path, fpath)
- self.media_repo._makedirs(fname)
+ file_info = FileInfo(
+ server_name=None,
+ file_id=file_id,
+ url_cache=True,
+ )
- try:
- with open(fname, "wb") as f:
+ with self.media_storage.store_into_file(file_info) as (f, fname, finish):
+ try:
logger.debug("Trying to get url '%s'" % url)
length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size,
)
+ except Exception as e:
# FIXME: pass through 404s and other error messages nicely
+ logger.warn("Error downloading %s: %r", url, e)
+ raise SynapseError(
+ 500, "Failed to download content: %s" % (
+ traceback.format_exception_only(sys.exc_type, e),
+ ),
+ Codes.UNKNOWN,
+ )
+ yield finish()
- yield self.media_repo.copy_to_backup(fpath)
-
- media_type = headers["Content-Type"][0]
+ try:
+ if "Content-Type" in headers:
+ media_type = headers["Content-Type"][0]
+ else:
+ media_type = "application/octet-stream"
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
@@ -327,11 +348,11 @@ class PreviewUrlResource(Resource):
)
except Exception as e:
- os.remove(fname)
- raise SynapseError(
- 500, ("Failed to download content: %s" % e),
- Codes.UNKNOWN
- )
+ logger.error("Error handling downloaded %s: %r", url, e)
+ # TODO: we really ought to delete the downloaded file in this
+ # case, since we won't have recorded it in the db, and will
+ # therefore not expire it.
+ raise
defer.returnValue({
"media_type": media_type,
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
new file mode 100644
index 0000000000..c188192f2b
--- /dev/null
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 New Vector Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer, threads
+
+from .media_storage import FileResponder
+
+from synapse.config._base import Config
+from synapse.util.logcontext import preserve_fn
+
+import logging
+import os
+import shutil
+
+
+logger = logging.getLogger(__name__)
+
+
+class StorageProvider(object):
+ """A storage provider is a service that can store uploaded media and
+ retrieve them.
+ """
+ def store_file(self, path, file_info):
+ """Store the file described by file_info. The actual contents can be
+ retrieved by reading the file in file_info.upload_path.
+
+ Args:
+ path (str): Relative path of file in local cache
+ file_info (FileInfo)
+
+ Returns:
+ Deferred
+ """
+ pass
+
+ def fetch(self, path, file_info):
+ """Attempt to fetch the file described by file_info and stream it
+ into writer.
+
+ Args:
+ path (str): Relative path of file in local cache
+ file_info (FileInfo)
+
+ Returns:
+ Deferred(Responder): Returns a Responder if the provider has the file,
+ otherwise returns None.
+ """
+ pass
+
+
+class StorageProviderWrapper(StorageProvider):
+ """Wraps a storage provider and provides various config options
+
+ Args:
+ backend (StorageProvider)
+ store_local (bool): Whether to store new local files or not.
+ store_synchronous (bool): Whether to wait for file to be successfully
+ uploaded, or todo the upload in the backgroud.
+ store_remote (bool): Whether remote media should be uploaded
+ """
+ def __init__(self, backend, store_local, store_synchronous, store_remote):
+ self.backend = backend
+ self.store_local = store_local
+ self.store_synchronous = store_synchronous
+ self.store_remote = store_remote
+
+ def store_file(self, path, file_info):
+ if not file_info.server_name and not self.store_local:
+ return defer.succeed(None)
+
+ if file_info.server_name and not self.store_remote:
+ return defer.succeed(None)
+
+ if self.store_synchronous:
+ return self.backend.store_file(path, file_info)
+ else:
+ # TODO: Handle errors.
+ preserve_fn(self.backend.store_file)(path, file_info)
+ return defer.succeed(None)
+
+ def fetch(self, path, file_info):
+ return self.backend.fetch(path, file_info)
+
+
+class FileStorageProviderBackend(StorageProvider):
+ """A storage provider that stores files in a directory on a filesystem.
+
+ Args:
+ hs (HomeServer)
+ config: The config returned by `parse_config`.
+ """
+
+ def __init__(self, hs, config):
+ self.cache_directory = hs.config.media_store_path
+ self.base_directory = config
+
+ def store_file(self, path, file_info):
+ """See StorageProvider.store_file"""
+
+ primary_fname = os.path.join(self.cache_directory, path)
+ backup_fname = os.path.join(self.base_directory, path)
+
+ dirname = os.path.dirname(backup_fname)
+ if not os.path.exists(dirname):
+ os.makedirs(dirname)
+
+ return threads.deferToThread(
+ shutil.copyfile, primary_fname, backup_fname,
+ )
+
+ def fetch(self, path, file_info):
+ """See StorageProvider.fetch"""
+
+ backup_fname = os.path.join(self.base_directory, path)
+ if os.path.isfile(backup_fname):
+ return FileResponder(open(backup_fname, "rb"))
+
+ @staticmethod
+ def parse_config(config):
+ """Called on startup to parse config supplied. This should parse
+ the config and raise if there is a problem.
+
+ The returned value is passed into the constructor.
+
+ In this case we only care about a single param, the directory, so let's
+ just pull that out.
+ """
+ return Config.ensure_directory(config["directory"])
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 68d56b2b10..58ada49711 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -14,7 +14,10 @@
# limitations under the License.
-from ._base import parse_media_id, respond_404, respond_with_file
+from ._base import (
+ parse_media_id, respond_404, respond_with_file, FileInfo,
+ respond_with_responder,
+)
from twisted.web.resource import Resource
from synapse.http.servlet import parse_string, parse_integer
from synapse.http.server import request_handler, set_cors_headers
@@ -30,12 +33,12 @@ logger = logging.getLogger(__name__)
class ThumbnailResource(Resource):
isLeaf = True
- def __init__(self, hs, media_repo):
+ def __init__(self, hs, media_repo, media_storage):
Resource.__init__(self)
self.store = hs.get_datastore()
- self.filepaths = media_repo.filepaths
self.media_repo = media_repo
+ self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
self.version_string = hs.version_string
@@ -64,6 +67,7 @@ class ThumbnailResource(Resource):
yield self._respond_local_thumbnail(
request, media_id, width, height, method, m_type
)
+ self.media_repo.mark_recently_accessed(None, media_id)
else:
if self.dynamic_thumbnails:
yield self._select_or_generate_remote_thumbnail(
@@ -75,20 +79,20 @@ class ThumbnailResource(Resource):
request, server_name, media_id,
width, height, method, m_type
)
+ self.media_repo.mark_recently_accessed(server_name, media_id)
@defer.inlineCallbacks
def _respond_local_thumbnail(self, request, media_id, width, height,
method, m_type):
media_info = yield self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
+ if not media_info:
+ respond_404(request)
+ return
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
respond_404(request)
return
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.local_media_filepath(media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
@@ -96,42 +100,39 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
-
- if media_info["url_cache"]:
- # TODO: Check the file still exists, if it doesn't we can redownload
- # it from the url `media_info["url_cache"]`
- file_path = self.filepaths.url_cache_thumbnail(
- media_id, t_width, t_height, t_type, t_method,
- )
- else:
- file_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method,
- )
- yield respond_with_file(request, t_type, file_path)
- else:
- yield self._respond_default_thumbnail(
- request, media_info, width, height, method, m_type,
+ file_info = FileInfo(
+ server_name=None, file_id=media_id,
+ url_cache=media_info["url_cache"],
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
)
+ t_type = file_info.thumbnail_type
+ t_length = thumbnail_info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(request, responder, t_type, t_length)
+ else:
+ logger.info("Couldn't find any generated thumbnails")
+ respond_404(request)
+
@defer.inlineCallbacks
def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
desired_height, desired_method,
desired_type):
media_info = yield self.store.get_local_media(media_id)
- if not media_info or media_info["quarantined_by"]:
+ if not media_info:
+ respond_404(request)
+ return
+ if media_info["quarantined_by"]:
+ logger.info("Media is quarantined")
respond_404(request)
return
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.local_media_filepath(media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
@@ -141,46 +142,43 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
- if media_info["url_cache"]:
- # TODO: Check the file still exists, if it doesn't we can redownload
- # it from the url `media_info["url_cache"]`
- file_path = self.filepaths.url_cache_thumbnail(
- media_id, desired_width, desired_height, desired_type,
- desired_method,
- )
- else:
- file_path = self.filepaths.local_media_thumbnail(
- media_id, desired_width, desired_height, desired_type,
- desired_method,
- )
- yield respond_with_file(request, desired_type, file_path)
- return
-
- logger.debug("We don't have a local thumbnail of that size. Generating")
+ file_info = FileInfo(
+ server_name=None, file_id=media_id,
+ url_cache=media_info["url_cache"],
+ thumbnail=True,
+ thumbnail_width=info["thumbnail_width"],
+ thumbnail_height=info["thumbnail_height"],
+ thumbnail_type=info["thumbnail_type"],
+ thumbnail_method=info["thumbnail_method"],
+ )
+
+ t_type = file_info.thumbnail_type
+ t_length = info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ yield respond_with_responder(request, responder, t_type, t_length)
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_local_exact_thumbnail(
- media_id, desired_width, desired_height, desired_method, desired_type
+ media_id, desired_width, desired_height, desired_method, desired_type,
+ url_cache=media_info["url_cache"],
)
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- yield self._respond_default_thumbnail(
- request, media_info, desired_width, desired_height,
- desired_method, desired_type,
- )
+ logger.warn("Failed to generate thumbnail")
+ respond_404(request)
@defer.inlineCallbacks
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
desired_width, desired_height,
desired_method, desired_type):
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
+ media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -195,14 +193,24 @@ class ThumbnailResource(Resource):
t_type = info["thumbnail_type"] == desired_type
if t_w and t_h and t_method and t_type:
- file_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, desired_width, desired_height,
- desired_type, desired_method,
+ file_info = FileInfo(
+ server_name=server_name, file_id=media_info["filesystem_id"],
+ thumbnail=True,
+ thumbnail_width=info["thumbnail_width"],
+ thumbnail_height=info["thumbnail_height"],
+ thumbnail_type=info["thumbnail_type"],
+ thumbnail_method=info["thumbnail_method"],
)
- yield respond_with_file(request, desired_type, file_path)
- return
- logger.debug("We don't have a local thumbnail of that size. Generating")
+ t_type = file_info.thumbnail_type
+ t_length = info["thumbnail_length"]
+
+ responder = yield self.media_storage.fetch_media(file_info)
+ if responder:
+ yield respond_with_responder(request, responder, t_type, t_length)
+ return
+
+ logger.debug("We don't have a thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
@@ -213,22 +221,16 @@ class ThumbnailResource(Resource):
if file_path:
yield respond_with_file(request, desired_type, file_path)
else:
- yield self._respond_default_thumbnail(
- request, media_info, desired_width, desired_height,
- desired_method, desired_type,
- )
+ logger.warn("Failed to generate thumbnail")
+ respond_404(request)
@defer.inlineCallbacks
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
height, method, m_type):
# TODO: Don't download the whole remote file
- # We should proxy the thumbnail from the remote server instead.
- media_info = yield self.media_repo.get_remote_media(server_name, media_id)
-
- # if media_info["media_type"] == "image/svg+xml":
- # file_path = self.filepaths.remote_media_filepath(server_name, media_id)
- # yield respond_with_file(request, media_info["media_type"], file_path)
- # return
+ # We should proxy the thumbnail from the remote server instead of
+ # downloading the remote file and generating our own thumbnails.
+ media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -238,59 +240,23 @@ class ThumbnailResource(Resource):
thumbnail_info = self._select_thumbnail(
width, height, method, m_type, thumbnail_infos
)
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
- file_id = thumbnail_info["filesystem_id"]
+ file_info = FileInfo(
+ server_name=server_name, file_id=media_info["filesystem_id"],
+ thumbnail=True,
+ thumbnail_width=thumbnail_info["thumbnail_width"],
+ thumbnail_height=thumbnail_info["thumbnail_height"],
+ thumbnail_type=thumbnail_info["thumbnail_type"],
+ thumbnail_method=thumbnail_info["thumbnail_method"],
+ )
+
+ t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
- file_path = self.filepaths.remote_media_thumbnail(
- server_name, file_id, t_width, t_height, t_type, t_method,
- )
- yield respond_with_file(request, t_type, file_path, t_length)
+ responder = yield self.media_storage.fetch_media(file_info)
+ yield respond_with_responder(request, responder, t_type, t_length)
else:
- yield self._respond_default_thumbnail(
- request, media_info, width, height, method, m_type,
- )
-
- @defer.inlineCallbacks
- def _respond_default_thumbnail(self, request, media_info, width, height,
- method, m_type):
- # XXX: how is this meant to work? store.get_default_thumbnails
- # appears to always return [] so won't this always 404?
- media_type = media_info["media_type"]
- top_level_type = media_type.split("/")[0]
- sub_type = media_type.split("/")[-1].split(";")[0]
- thumbnail_infos = yield self.store.get_default_thumbnails(
- top_level_type, sub_type,
- )
- if not thumbnail_infos:
- thumbnail_infos = yield self.store.get_default_thumbnails(
- top_level_type, "_default",
- )
- if not thumbnail_infos:
- thumbnail_infos = yield self.store.get_default_thumbnails(
- "_default", "_default",
- )
- if not thumbnail_infos:
+ logger.info("Failed to find any generated thumbnails")
respond_404(request)
- return
-
- thumbnail_info = self._select_thumbnail(
- width, height, "crop", m_type, thumbnail_infos
- )
-
- t_width = thumbnail_info["thumbnail_width"]
- t_height = thumbnail_info["thumbnail_height"]
- t_type = thumbnail_info["thumbnail_type"]
- t_method = thumbnail_info["thumbnail_method"]
- t_length = thumbnail_info["thumbnail_length"]
-
- file_path = self.filepaths.default_thumbnail(
- top_level_type, sub_type, t_width, t_height, t_type, t_method,
- )
- yield respond_with_file(request, t_type, file_path, t_length)
def _select_thumbnail(self, desired_width, desired_height, desired_method,
desired_type, thumbnail_infos):
|