diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index b2257b749d..2444f27366 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -20,14 +20,32 @@ from synapse.types import UserID
from base import ClientV1RestServlet, client_path_pattern
import simplejson as json
+import urllib
+
+import logging
+from saml2 import BINDING_HTTP_POST
+from saml2 import config
+from saml2.client import Saml2Client
+
+
+logger = logging.getLogger(__name__)
class LoginRestServlet(ClientV1RestServlet):
PATTERN = client_path_pattern("/login$")
PASS_TYPE = "m.login.password"
+ SAML2_TYPE = "m.login.saml2"
+
+ def __init__(self, hs):
+ super(LoginRestServlet, self).__init__(hs)
+ self.idp_redirect_url = hs.config.saml2_idp_redirect_url
+ self.saml2_enabled = hs.config.saml2_enabled
def on_GET(self, request):
- return (200, {"flows": [{"type": LoginRestServlet.PASS_TYPE}]})
+ flows = [{"type": LoginRestServlet.PASS_TYPE}]
+ if self.saml2_enabled:
+ flows.append({"type": LoginRestServlet.SAML2_TYPE})
+ return (200, {"flows": flows})
def on_OPTIONS(self, request):
return (200, {})
@@ -39,6 +57,16 @@ class LoginRestServlet(ClientV1RestServlet):
if login_submission["type"] == LoginRestServlet.PASS_TYPE:
result = yield self.do_password_login(login_submission)
defer.returnValue(result)
+ elif self.saml2_enabled and (login_submission["type"] ==
+ LoginRestServlet.SAML2_TYPE):
+ relay_state = ""
+ if "relay_state" in login_submission:
+ relay_state = "&RelayState="+urllib.quote(
+ login_submission["relay_state"])
+ result = {
+ "uri": "%s%s" % (self.idp_redirect_url, relay_state)
+ }
+ defer.returnValue((200, result))
else:
raise SynapseError(400, "Bad login type.")
except KeyError:
@@ -46,17 +74,24 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def do_password_login(self, login_submission):
- if not login_submission["user"].startswith('@'):
- login_submission["user"] = UserID.create(
- login_submission["user"], self.hs.hostname).to_string()
+ if 'medium' in login_submission and 'address' in login_submission:
+ user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
+ login_submission['medium'], login_submission['address']
+ )
+ else:
+ user_id = login_submission['user']
+
+ if not user_id.startswith('@'):
+ user_id = UserID.create(
+ user_id, self.hs.hostname
+ ).to_string()
- handler = self.handlers.login_handler
- token = yield handler.login(
- user=login_submission["user"],
+ user_id, token = yield self.handlers.auth_handler.login_with_password(
+ user_id=user_id,
password=login_submission["password"])
result = {
- "user_id": login_submission["user"], # may have changed
+ "user_id": user_id, # may have changed
"access_token": token,
"home_server": self.hs.hostname,
}
@@ -94,6 +129,49 @@ class PasswordResetRestServlet(ClientV1RestServlet):
)
+class SAML2RestServlet(ClientV1RestServlet):
+ PATTERN = client_path_pattern("/login/saml2")
+
+ def __init__(self, hs):
+ super(SAML2RestServlet, self).__init__(hs)
+ self.sp_config = hs.config.saml2_config_path
+
+ @defer.inlineCallbacks
+ def on_POST(self, request):
+ saml2_auth = None
+ try:
+ conf = config.SPConfig()
+ conf.load_file(self.sp_config)
+ SP = Saml2Client(conf)
+ saml2_auth = SP.parse_authn_request_response(
+ request.args['SAMLResponse'][0], BINDING_HTTP_POST)
+ except Exception, e: # Not authenticated
+ logger.exception(e)
+ if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
+ username = saml2_auth.name_id.text
+ handler = self.handlers.registration_handler
+ (user_id, token) = yield handler.register_saml2(username)
+ # Forward to the RelayState callback along with ava
+ if 'RelayState' in request.args:
+ request.redirect(urllib.unquote(
+ request.args['RelayState'][0]) +
+ '?status=authenticated&access_token=' +
+ token + '&user_id=' + user_id + '&ava=' +
+ urllib.quote(json.dumps(saml2_auth.ava)))
+ request.finish()
+ defer.returnValue(None)
+ defer.returnValue((200, {"status": "authenticated",
+ "user_id": user_id, "token": token,
+ "ava": saml2_auth.ava}))
+ elif 'RelayState' in request.args:
+ request.redirect(urllib.unquote(
+ request.args['RelayState'][0]) +
+ '?status=not_authenticated')
+ request.finish()
+ defer.returnValue(None)
+ defer.returnValue((200, {"status": "not_authenticated"}))
+
+
def _parse_json(request):
try:
content = json.loads(request.content.read())
@@ -106,4 +184,6 @@ def _parse_json(request):
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
+ if hs.config.saml2_enabled:
+ SAML2RestServlet(hs).register(http_server)
# TODO PasswordResetRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 0346afb1b4..b4a70cba99 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -412,6 +412,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
if "user_id" not in content:
raise SynapseError(400, "Missing user_id key.")
state_key = content["user_id"]
+ # make sure it looks like a user ID; it'll throw if it's invalid.
+ UserID.from_string(state_key)
if membership_action == "kick":
membership_action = "leave"
diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py
index 7d1aff4307..33f961e898 100644
--- a/synapse/rest/client/v2_alpha/__init__.py
+++ b/synapse/rest/client/v2_alpha/__init__.py
@@ -18,7 +18,9 @@ from . import (
filter,
account,
register,
- auth
+ auth,
+ receipts,
+ keys,
)
from synapse.http.server import JsonResource
@@ -38,3 +40,5 @@ class ClientV2AlphaRestResource(JsonResource):
account.register_servlets(hs, client_resource)
register.register_servlets(hs, client_resource)
auth.register_servlets(hs, client_resource)
+ receipts.register_servlets(hs, client_resource)
+ keys.register_servlets(hs, client_resource)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index b082140f1f..522a312c9e 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -36,7 +36,6 @@ class PasswordRestServlet(RestServlet):
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_handlers().auth_handler
- self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks
def on_POST(self, request):
@@ -47,7 +46,7 @@ class PasswordRestServlet(RestServlet):
authed, result, params = yield self.auth_handler.check_auth([
[LoginType.PASSWORD],
[LoginType.EMAIL_IDENTITY]
- ], body)
+ ], body, self.hs.get_ip_from_request(request))
if not authed:
defer.returnValue((401, result))
@@ -79,8 +78,8 @@ class PasswordRestServlet(RestServlet):
raise SynapseError(400, "", Codes.MISSING_PARAM)
new_password = params['new_password']
- yield self.login_handler.set_password(
- user_id, new_password, None
+ yield self.auth_handler.set_password(
+ user_id, new_password
)
defer.returnValue((200, {}))
@@ -95,7 +94,6 @@ class ThreepidRestServlet(RestServlet):
def __init__(self, hs):
super(ThreepidRestServlet, self).__init__()
self.hs = hs
- self.login_handler = hs.get_handlers().login_handler
self.identity_handler = hs.get_handlers().identity_handler
self.auth = hs.get_auth()
@@ -135,7 +133,7 @@ class ThreepidRestServlet(RestServlet):
logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
raise SynapseError(500, "Invalid response from ID Server")
- yield self.login_handler.add_threepid(
+ yield self.auth_handler.add_threepid(
auth_user.to_string(),
threepid['medium'],
threepid['address'],
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
new file mode 100644
index 0000000000..718928eedd
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -0,0 +1,316 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import SynapseError
+from synapse.http.servlet import RestServlet
+from synapse.types import UserID
+from syutil.jsonutil import encode_canonical_json
+
+from ._base import client_v2_pattern
+
+import simplejson as json
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class KeyUploadServlet(RestServlet):
+ """
+ POST /keys/upload/<device_id> HTTP/1.1
+ Content-Type: application/json
+
+ {
+ "device_keys": {
+ "user_id": "<user_id>",
+ "device_id": "<device_id>",
+ "valid_until_ts": <millisecond_timestamp>,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha256",
+ ]
+ "keys": {
+ "<algorithm>:<device_id>": "<key_base64>",
+ },
+ "signatures:" {
+ "<user_id>" {
+ "<algorithm>:<device_id>": "<signature_base64>"
+ } } },
+ "one_time_keys": {
+ "<algorithm>:<key_id>": "<key_base64>"
+ },
+ }
+ """
+ PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)")
+
+ def __init__(self, hs):
+ super(KeyUploadServlet, self).__init__()
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.auth = hs.get_auth()
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, device_id):
+ auth_user, client_info = yield self.auth.get_user_by_req(request)
+ user_id = auth_user.to_string()
+ # TODO: Check that the device_id matches that in the authentication
+ # or derive the device_id from the authentication instead.
+ try:
+ body = json.loads(request.content.read())
+ except:
+ raise SynapseError(400, "Invalid key JSON")
+ time_now = self.clock.time_msec()
+
+ # TODO: Validate the JSON to make sure it has the right keys.
+ device_keys = body.get("device_keys", None)
+ if device_keys:
+ logger.info(
+ "Updating device_keys for device %r for user %r at %d",
+ device_id, auth_user, time_now
+ )
+ # TODO: Sign the JSON with the server key
+ yield self.store.set_e2e_device_keys(
+ user_id, device_id, time_now,
+ encode_canonical_json(device_keys)
+ )
+
+ one_time_keys = body.get("one_time_keys", None)
+ if one_time_keys:
+ logger.info(
+ "Adding %d one_time_keys for device %r for user %r at %d",
+ len(one_time_keys), device_id, user_id, time_now
+ )
+ key_list = []
+ for key_id, key_json in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append((
+ algorithm, key_id, encode_canonical_json(key_json)
+ ))
+
+ yield self.store.add_e2e_one_time_keys(
+ user_id, device_id, time_now, key_list
+ )
+
+ result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+ defer.returnValue((200, {"one_time_key_counts": result}))
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, device_id):
+ auth_user, client_info = yield self.auth.get_user_by_req(request)
+ user_id = auth_user.to_string()
+
+ result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
+ defer.returnValue((200, {"one_time_key_counts": result}))
+
+
+class KeyQueryServlet(RestServlet):
+ """
+ GET /keys/query/<user_id> HTTP/1.1
+
+ GET /keys/query/<user_id>/<device_id> HTTP/1.1
+
+ POST /keys/query HTTP/1.1
+ Content-Type: application/json
+ {
+ "device_keys": {
+ "<user_id>": ["<device_id>"]
+ } }
+
+ HTTP/1.1 200 OK
+ {
+ "device_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ "user_id": "<user_id>", // Duplicated to be signed
+ "device_id": "<device_id>", // Duplicated to be signed
+ "valid_until_ts": <millisecond_timestamp>,
+ "algorithms": [ // List of supported algorithms
+ "m.olm.curve25519-aes-sha256",
+ ],
+ "keys": { // Must include a ed25519 signing key
+ "<algorithm>:<key_id>": "<key_base64>",
+ },
+ "signatures:" {
+ // Must be signed with device's ed25519 key
+ "<user_id>/<device_id>": {
+ "<algorithm>:<key_id>": "<signature_base64>"
+ }
+ // Must be signed by this server.
+ "<server_name>": {
+ "<algorithm>:<key_id>": "<signature_base64>"
+ } } } } } }
+ """
+
+ PATTERN = client_v2_pattern(
+ "/keys/query(?:"
+ "/(?P<user_id>[^/]*)(?:"
+ "/(?P<device_id>[^/]*)"
+ ")?"
+ ")?"
+ )
+
+ def __init__(self, hs):
+ super(KeyQueryServlet, self).__init__()
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.federation = hs.get_replication_layer()
+ self.is_mine = hs.is_mine
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, user_id, device_id):
+ yield self.auth.get_user_by_req(request)
+ try:
+ body = json.loads(request.content.read())
+ except:
+ raise SynapseError(400, "Invalid key JSON")
+ result = yield self.handle_request(body)
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id, device_id):
+ auth_user, client_info = yield self.auth.get_user_by_req(request)
+ auth_user_id = auth_user.to_string()
+ user_id = user_id if user_id else auth_user_id
+ device_ids = [device_id] if device_id else []
+ result = yield self.handle_request(
+ {"device_keys": {user_id: device_ids}}
+ )
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def handle_request(self, body):
+ local_query = []
+ remote_queries = {}
+ for user_id, device_ids in body.get("device_keys", {}).items():
+ user = UserID.from_string(user_id)
+ if self.is_mine(user):
+ if not device_ids:
+ local_query.append((user_id, None))
+ else:
+ for device_id in device_ids:
+ local_query.append((user_id, device_id))
+ else:
+ remote_queries.setdefault(user.domain, {})[user_id] = list(
+ device_ids
+ )
+ results = yield self.store.get_e2e_device_keys(local_query)
+
+ json_result = {}
+ for user_id, device_keys in results.items():
+ for device_id, json_bytes in device_keys.items():
+ json_result.setdefault(user_id, {})[device_id] = json.loads(
+ json_bytes
+ )
+
+ for destination, device_keys in remote_queries.items():
+ remote_result = yield self.federation.query_client_keys(
+ destination, {"device_keys": device_keys}
+ )
+ for user_id, keys in remote_result["device_keys"].items():
+ if user_id in device_keys:
+ json_result[user_id] = keys
+ defer.returnValue((200, {"device_keys": json_result}))
+
+
+class OneTimeKeyServlet(RestServlet):
+ """
+ GET /keys/claim/<user-id>/<device-id>/<algorithm> HTTP/1.1
+
+ POST /keys/claim HTTP/1.1
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": "<algorithm>"
+ } } }
+
+ HTTP/1.1 200 OK
+ {
+ "one_time_keys": {
+ "<user_id>": {
+ "<device_id>": {
+ "<algorithm>:<key_id>": "<key_base64>"
+ } } } }
+
+ """
+ PATTERN = client_v2_pattern(
+ "/keys/claim(?:/?|(?:/"
+ "(?P<user_id>[^/]*)/(?P<device_id>[^/]*)/(?P<algorithm>[^/]*)"
+ ")?)"
+ )
+
+ def __init__(self, hs):
+ super(OneTimeKeyServlet, self).__init__()
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.clock = hs.get_clock()
+ self.federation = hs.get_replication_layer()
+ self.is_mine = hs.is_mine
+
+ @defer.inlineCallbacks
+ def on_GET(self, request, user_id, device_id, algorithm):
+ yield self.auth.get_user_by_req(request)
+ result = yield self.handle_request(
+ {"one_time_keys": {user_id: {device_id: algorithm}}}
+ )
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, user_id, device_id, algorithm):
+ yield self.auth.get_user_by_req(request)
+ try:
+ body = json.loads(request.content.read())
+ except:
+ raise SynapseError(400, "Invalid key JSON")
+ result = yield self.handle_request(body)
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def handle_request(self, body):
+ local_query = []
+ remote_queries = {}
+ for user_id, device_keys in body.get("one_time_keys", {}).items():
+ user = UserID.from_string(user_id)
+ if self.is_mine(user):
+ for device_id, algorithm in device_keys.items():
+ local_query.append((user_id, device_id, algorithm))
+ else:
+ remote_queries.setdefault(user.domain, {})[user_id] = (
+ device_keys
+ )
+ results = yield self.store.claim_e2e_one_time_keys(local_query)
+
+ json_result = {}
+ for user_id, device_keys in results.items():
+ for device_id, keys in device_keys.items():
+ for key_id, json_bytes in keys.items():
+ json_result.setdefault(user_id, {})[device_id] = {
+ key_id: json.loads(json_bytes)
+ }
+
+ for destination, device_keys in remote_queries.items():
+ remote_result = yield self.federation.claim_client_keys(
+ destination, {"one_time_keys": device_keys}
+ )
+ for user_id, keys in remote_result["one_time_keys"].items():
+ if user_id in device_keys:
+ json_result[user_id] = keys
+
+ defer.returnValue((200, {"one_time_keys": json_result}))
+
+
+def register_servlets(hs, http_server):
+ KeyUploadServlet(hs).register(http_server)
+ KeyQueryServlet(hs).register(http_server)
+ OneTimeKeyServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py
new file mode 100644
index 0000000000..40406e2ede
--- /dev/null
+++ b/synapse/rest/client/v2_alpha/receipts.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.http.servlet import RestServlet
+from ._base import client_v2_pattern
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class ReceiptRestServlet(RestServlet):
+ PATTERN = client_v2_pattern(
+ "/rooms/(?P<room_id>[^/]*)"
+ "/receipt/(?P<receipt_type>[^/]*)"
+ "/(?P<event_id>[^/]*)$"
+ )
+
+ def __init__(self, hs):
+ super(ReceiptRestServlet, self).__init__()
+ self.hs = hs
+ self.auth = hs.get_auth()
+ self.receipts_handler = hs.get_handlers().receipts_handler
+
+ @defer.inlineCallbacks
+ def on_POST(self, request, room_id, receipt_type, event_id):
+ user, client = yield self.auth.get_user_by_req(request)
+
+ yield self.receipts_handler.received_client_receipt(
+ room_id,
+ receipt_type,
+ user_id=user.to_string(),
+ event_id=event_id
+ )
+
+ defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+ ReceiptRestServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 72dfb876c5..1ba2f29711 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -19,7 +19,7 @@ from synapse.api.constants import LoginType
from synapse.api.errors import SynapseError, Codes
from synapse.http.servlet import RestServlet
-from ._base import client_v2_pattern, parse_request_allow_empty
+from ._base import client_v2_pattern, parse_json_dict_from_request
import logging
import hmac
@@ -50,26 +50,64 @@ class RegisterRestServlet(RestServlet):
self.auth_handler = hs.get_handlers().auth_handler
self.registration_handler = hs.get_handlers().registration_handler
self.identity_handler = hs.get_handlers().identity_handler
- self.login_handler = hs.get_handlers().login_handler
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
- body = parse_request_allow_empty(request)
- if 'password' not in body:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
+ if '/register/email/requestToken' in request.path:
+ ret = yield self.onEmailTokenRequest(request)
+ defer.returnValue(ret)
+ body = parse_json_dict_from_request(request)
+
+ # we do basic sanity checks here because the auth layer will store these
+ # in sessions. Pull out the username/password provided to us.
+ desired_password = None
+ if 'password' in body:
+ if (not isinstance(body['password'], basestring) or
+ len(body['password']) > 512):
+ raise SynapseError(400, "Invalid password")
+ desired_password = body["password"]
+
+ desired_username = None
if 'username' in body:
+ if (not isinstance(body['username'], basestring) or
+ len(body['username']) > 512):
+ raise SynapseError(400, "Invalid username")
desired_username = body['username']
- yield self.registration_handler.check_username(desired_username)
- is_using_shared_secret = False
- is_application_server = False
-
- service = None
+ appservice = None
if 'access_token' in request.args:
- service = yield self.auth.get_appservice_by_req(request)
+ appservice = yield self.auth.get_appservice_by_req(request)
+
+ # fork off as soon as possible for ASes and shared secret auth which
+ # have completely different registration flows to normal users
+
+ # == Application Service Registration ==
+ if appservice:
+ result = yield self._do_appservice_registration(
+ desired_username, request.args["access_token"][0]
+ )
+ defer.returnValue((200, result)) # we throw for non 200 responses
+ return
+
+ # == Shared Secret Registration == (e.g. create new user scripts)
+ if 'mac' in body:
+ # FIXME: Should we really be determining if this is shared secret
+ # auth based purely on the 'mac' key?
+ result = yield self._do_shared_secret_registration(
+ desired_username, desired_password, body["mac"]
+ )
+ defer.returnValue((200, result)) # we throw for non 200 responses
+ return
+
+ # == Normal User Registration == (everyone else)
+ if self.hs.config.disable_registration:
+ raise SynapseError(403, "Registration has been disabled")
+
+ if desired_username is not None:
+ yield self.registration_handler.check_username(desired_username)
if self.hs.config.enable_registration_captcha:
flows = [
@@ -82,39 +120,20 @@ class RegisterRestServlet(RestServlet):
[LoginType.EMAIL_IDENTITY]
]
- result = None
- if service:
- is_application_server = True
- params = body
- elif 'mac' in body:
- # Check registration-specific shared secret auth
- if 'username' not in body:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
- self._check_shared_secret_auth(
- body['username'], body['mac']
- )
- is_using_shared_secret = True
- params = body
- else:
- authed, result, params = yield self.auth_handler.check_auth(
- flows, body, self.hs.get_ip_from_request(request)
- )
-
- if not authed:
- defer.returnValue((401, result))
-
- can_register = (
- not self.hs.config.disable_registration
- or is_application_server
- or is_using_shared_secret
+ authed, result, params = yield self.auth_handler.check_auth(
+ flows, body, self.hs.get_ip_from_request(request)
)
- if not can_register:
- raise SynapseError(403, "Registration has been disabled")
+ if not authed:
+ defer.returnValue((401, result))
+ return
+
+ # NB: This may be from the auth handler and NOT from the POST
if 'password' not in params:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
- desired_username = params['username'] if 'username' in params else None
- new_password = params['password']
+ raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
+
+ desired_username = params.get("username", None)
+ new_password = params.get("password", None)
(user_id, token) = yield self.registration_handler.register(
localpart=desired_username,
@@ -128,7 +147,7 @@ class RegisterRestServlet(RestServlet):
if reqd not in threepid:
logger.info("Can't add incomplete 3pid")
else:
- yield self.login_handler.add_threepid(
+ yield self.auth_handler.add_threepid(
user_id,
threepid['medium'],
threepid['address'],
@@ -147,18 +166,21 @@ class RegisterRestServlet(RestServlet):
else:
logger.info("bind_email not specified: not binding email")
- result = {
- "user_id": user_id,
- "access_token": token,
- "home_server": self.hs.hostname,
- }
-
+ result = self._create_registration_details(user_id, token)
defer.returnValue((200, result))
def on_OPTIONS(self, _):
return 200, {}
- def _check_shared_secret_auth(self, username, mac):
+ @defer.inlineCallbacks
+ def _do_appservice_registration(self, username, as_token):
+ (user_id, token) = yield self.registration_handler.appservice_register(
+ username, as_token
+ )
+ defer.returnValue(self._create_registration_details(user_id, token))
+
+ @defer.inlineCallbacks
+ def _do_shared_secret_registration(self, username, password, mac):
if not self.hs.config.registration_shared_secret:
raise SynapseError(400, "Shared secret registration is not enabled")
@@ -174,13 +196,46 @@ class RegisterRestServlet(RestServlet):
digestmod=sha1,
).hexdigest()
- if compare_digest(want_mac, got_mac):
- return True
- else:
+ if not compare_digest(want_mac, got_mac):
raise SynapseError(
403, "HMAC incorrect",
)
+ (user_id, token) = yield self.registration_handler.register(
+ localpart=username, password=password
+ )
+ defer.returnValue(self._create_registration_details(user_id, token))
+
+ def _create_registration_details(self, user_id, token):
+ return {
+ "user_id": user_id,
+ "access_token": token,
+ "home_server": self.hs.hostname,
+ }
+
+ @defer.inlineCallbacks
+ def onEmailTokenRequest(self, request):
+ body = parse_json_dict_from_request(request)
+
+ required = ['id_server', 'client_secret', 'email', 'send_attempt']
+ absent = []
+ for k in required:
+ if k not in body:
+ absent.append(k)
+
+ if len(absent) > 0:
+ raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
+
+ existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
+ 'email', body['email']
+ )
+
+ if existingUid is not None:
+ raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
+
+ ret = yield self.identity_handler.requestEmailToken(**body)
+ defer.returnValue((200, ret))
+
def register_servlets(hs, http_server):
RegisterRestServlet(hs).register(http_server)
diff --git a/synapse/rest/media/v1/base_resource.py b/synapse/rest/media/v1/base_resource.py
index 6c83a9478c..b2aeb8c909 100644
--- a/synapse/rest/media/v1/base_resource.py
+++ b/synapse/rest/media/v1/base_resource.py
@@ -27,18 +27,30 @@ from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender
from synapse.util.async import ObservableDeferred
+from synapse.util.stringutils import is_ascii
import os
+import cgi
import logging
+import urllib
+import urlparse
logger = logging.getLogger(__name__)
def parse_media_id(request):
try:
- server_name, media_id = request.postpath
- return (server_name, media_id)
+ # This allows users to append e.g. /test.png to the URL. Useful for
+ # clients that parse the URL to see content type.
+ server_name, media_id = request.postpath[:2]
+ file_name = None
+ if len(request.postpath) > 2:
+ try:
+ file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
+ except UnicodeDecodeError:
+ pass
+ return server_name, media_id, file_name
except:
raise SynapseError(
404,
@@ -62,6 +74,8 @@ class BaseMediaResource(Resource):
self.filepaths = filepaths
self.version_string = hs.version_string
self.downloads = {}
+ self.dynamic_thumbnails = hs.config.dynamic_thumbnails
+ self.thumbnail_requirements = hs.config.thumbnail_requirements
def _respond_404(self, request):
respond_with_json(
@@ -128,12 +142,38 @@ class BaseMediaResource(Resource):
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
+ content_disposition = headers.get("Content-Disposition", None)
+ if content_disposition:
+ _, params = cgi.parse_header(content_disposition[0],)
+ upload_name = None
+
+ # First check if there is a valid UTF-8 filename
+ upload_name_utf8 = params.get("filename*", None)
+ if upload_name_utf8:
+ if upload_name_utf8.lower().startswith("utf-8''"):
+ upload_name = upload_name_utf8[7:]
+
+ # If there isn't check for an ascii name.
+ if not upload_name:
+ upload_name_ascii = params.get("filename", None)
+ if upload_name_ascii and is_ascii(upload_name_ascii):
+ upload_name = upload_name_ascii
+
+ if upload_name:
+ upload_name = urlparse.unquote(upload_name)
+ try:
+ upload_name = upload_name.decode("utf-8")
+ except UnicodeDecodeError:
+ upload_name = None
+ else:
+ upload_name = None
+
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
- upload_name=None,
+ upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
@@ -144,7 +184,7 @@ class BaseMediaResource(Resource):
media_info = {
"media_type": media_type,
"media_length": length,
- "upload_name": None,
+ "upload_name": upload_name,
"created_ts": time_now_ms,
"filesystem_id": file_id,
}
@@ -157,11 +197,26 @@ class BaseMediaResource(Resource):
@defer.inlineCallbacks
def _respond_with_file(self, request, media_type, file_path,
- file_size=None):
+ file_size=None, upload_name=None):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
+ if upload_name:
+ if is_ascii(upload_name):
+ request.setHeader(
+ b"Content-Disposition",
+ b"inline; filename=%s" % (
+ urllib.quote(upload_name.encode("utf-8")),
+ ),
+ )
+ else:
+ request.setHeader(
+ b"Content-Disposition",
+ b"inline; filename*=utf-8''%s" % (
+ urllib.quote(upload_name.encode("utf-8")),
+ ),
+ )
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
@@ -187,22 +242,74 @@ class BaseMediaResource(Resource):
self._respond_404(request)
def _get_thumbnail_requirements(self, media_type):
- if media_type == "image/jpeg":
- return (
- (32, 32, "crop", "image/jpeg"),
- (96, 96, "crop", "image/jpeg"),
- (320, 240, "scale", "image/jpeg"),
- (640, 480, "scale", "image/jpeg"),
- )
- elif (media_type == "image/png") or (media_type == "image/gif"):
- return (
- (32, 32, "crop", "image/png"),
- (96, 96, "crop", "image/png"),
- (320, 240, "scale", "image/png"),
- (640, 480, "scale", "image/png"),
+ return self.thumbnail_requirements.get(media_type, ())
+
+ def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
+ t_method, t_type):
+ thumbnailer = Thumbnailer(input_path)
+ m_width = thumbnailer.width
+ m_height = thumbnailer.height
+
+ if m_width * m_height >= self.max_image_pixels:
+ logger.info(
+ "Image too large to thumbnail %r x %r > %r",
+ m_width, m_height, self.max_image_pixels
)
+ return
+
+ if t_method == "crop":
+ t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+ elif t_method == "scale":
+ t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
else:
- return ()
+ t_len = None
+
+ return t_len
+
+ @defer.inlineCallbacks
+ def _generate_local_exact_thumbnail(self, media_id, t_width, t_height,
+ t_method, t_type):
+ input_path = self.filepaths.local_media_filepath(media_id)
+
+ t_path = self.filepaths.local_media_thumbnail(
+ media_id, t_width, t_height, t_type, t_method
+ )
+ self._makedirs(t_path)
+
+ t_len = yield threads.deferToThread(
+ self._generate_thumbnail,
+ input_path, t_path, t_width, t_height, t_method, t_type
+ )
+
+ if t_len:
+ yield self.store.store_local_thumbnail(
+ media_id, t_width, t_height, t_type, t_method, t_len
+ )
+
+ defer.returnValue(t_path)
+
+ @defer.inlineCallbacks
+ def _generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
+ t_width, t_height, t_method, t_type):
+ input_path = self.filepaths.remote_media_filepath(server_name, file_id)
+
+ t_path = self.filepaths.remote_media_thumbnail(
+ server_name, file_id, t_width, t_height, t_type, t_method
+ )
+ self._makedirs(t_path)
+
+ t_len = yield threads.deferToThread(
+ self._generate_thumbnail,
+ input_path, t_path, t_width, t_height, t_method, t_type
+ )
+
+ if t_len:
+ yield self.store.store_remote_media_thumbnail(
+ server_name, media_id, file_id,
+ t_width, t_height, t_type, t_method, t_len
+ )
+
+ defer.returnValue(t_path)
@defer.inlineCallbacks
def _generate_local_thumbnails(self, media_id, media_info):
@@ -223,43 +330,52 @@ class BaseMediaResource(Resource):
)
return
- scales = set()
- crops = set()
- for r_width, r_height, r_method, r_type in requirements:
- if r_method == "scale":
- t_width, t_height = thumbnailer.aspect(r_width, r_height)
- scales.add((
- min(m_width, t_width), min(m_height, t_height), r_type,
+ local_thumbnails = []
+
+ def generate_thumbnails():
+ scales = set()
+ crops = set()
+ for r_width, r_height, r_method, r_type in requirements:
+ if r_method == "scale":
+ t_width, t_height = thumbnailer.aspect(r_width, r_height)
+ scales.add((
+ min(m_width, t_width), min(m_height, t_height), r_type,
+ ))
+ elif r_method == "crop":
+ crops.add((r_width, r_height, r_type))
+
+ for t_width, t_height, t_type in scales:
+ t_method = "scale"
+ t_path = self.filepaths.local_media_thumbnail(
+ media_id, t_width, t_height, t_type, t_method
+ )
+ self._makedirs(t_path)
+ t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
+
+ local_thumbnails.append((
+ media_id, t_width, t_height, t_type, t_method, t_len
))
- elif r_method == "crop":
- crops.add((r_width, r_height, r_type))
- for t_width, t_height, t_type in scales:
- t_method = "scale"
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
- yield self.store.store_local_thumbnail(
- media_id, t_width, t_height, t_type, t_method, t_len
- )
+ for t_width, t_height, t_type in crops:
+ if (t_width, t_height, t_type) in scales:
+ # If the aspect ratio of the cropped thumbnail matches a purely
+ # scaled one then there is no point in calculating a separate
+ # thumbnail.
+ continue
+ t_method = "crop"
+ t_path = self.filepaths.local_media_thumbnail(
+ media_id, t_width, t_height, t_type, t_method
+ )
+ self._makedirs(t_path)
+ t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
+ local_thumbnails.append((
+ media_id, t_width, t_height, t_type, t_method, t_len
+ ))
- for t_width, t_height, t_type in crops:
- if (t_width, t_height, t_type) in scales:
- # If the aspect ratio of the cropped thumbnail matches a purely
- # scaled one then there is no point in calculating a separate
- # thumbnail.
- continue
- t_method = "crop"
- t_path = self.filepaths.local_media_thumbnail(
- media_id, t_width, t_height, t_type, t_method
- )
- self._makedirs(t_path)
- t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
- yield self.store.store_local_thumbnail(
- media_id, t_width, t_height, t_type, t_method, t_len
- )
+ yield threads.deferToThread(generate_thumbnails)
+
+ for l in local_thumbnails:
+ yield self.store.store_local_thumbnail(*l)
defer.returnValue({
"width": m_width,
diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py
index 0fe6abf647..ab384e5388 100644
--- a/synapse/rest/media/v1/download_resource.py
+++ b/synapse/rest/media/v1/download_resource.py
@@ -32,14 +32,16 @@ class DownloadResource(BaseMediaResource):
@request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
- server_name, media_id = parse_media_id(request)
+ server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
- yield self._respond_local_file(request, media_id)
+ yield self._respond_local_file(request, media_id, name)
else:
- yield self._respond_remote_file(request, server_name, media_id)
+ yield self._respond_remote_file(
+ request, server_name, media_id, name
+ )
@defer.inlineCallbacks
- def _respond_local_file(self, request, media_id):
+ def _respond_local_file(self, request, media_id, name):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
self._respond_404(request)
@@ -47,24 +49,28 @@ class DownloadResource(BaseMediaResource):
media_type = media_info["media_type"]
media_length = media_info["media_length"]
+ upload_name = name if name else media_info["upload_name"]
file_path = self.filepaths.local_media_filepath(media_id)
yield self._respond_with_file(
- request, media_type, file_path, media_length
+ request, media_type, file_path, media_length,
+ upload_name=upload_name,
)
@defer.inlineCallbacks
- def _respond_remote_file(self, request, server_name, media_id):
+ def _respond_remote_file(self, request, server_name, media_id, name):
media_info = yield self._get_remote_media(server_name, media_id)
media_type = media_info["media_type"]
media_length = media_info["media_length"]
filesystem_id = media_info["filesystem_id"]
+ upload_name = name if name else media_info["upload_name"]
file_path = self.filepaths.remote_media_filepath(
server_name, filesystem_id
)
yield self._respond_with_file(
- request, media_type, file_path, media_length
+ request, media_type, file_path, media_length,
+ upload_name=upload_name,
)
diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py
index 1dadd880b2..e506dad934 100644
--- a/synapse/rest/media/v1/thumbnail_resource.py
+++ b/synapse/rest/media/v1/thumbnail_resource.py
@@ -36,21 +36,32 @@ class ThumbnailResource(BaseMediaResource):
@request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
- server_name, media_id = parse_media_id(request)
+ server_name, media_id, _ = parse_media_id(request)
width = parse_integer(request, "width")
height = parse_integer(request, "height")
method = parse_string(request, "method", "scale")
m_type = parse_string(request, "type", "image/png")
if server_name == self.server_name:
- yield self._respond_local_thumbnail(
- request, media_id, width, height, method, m_type
- )
+ if self.dynamic_thumbnails:
+ yield self._select_or_generate_local_thumbnail(
+ request, media_id, width, height, method, m_type
+ )
+ else:
+ yield self._respond_local_thumbnail(
+ request, media_id, width, height, method, m_type
+ )
else:
- yield self._respond_remote_thumbnail(
- request, server_name, media_id,
- width, height, method, m_type
- )
+ if self.dynamic_thumbnails:
+ yield self._select_or_generate_remote_thumbnail(
+ request, server_name, media_id,
+ width, height, method, m_type
+ )
+ else:
+ yield self._respond_remote_thumbnail(
+ request, server_name, media_id,
+ width, height, method, m_type
+ )
@defer.inlineCallbacks
def _respond_local_thumbnail(self, request, media_id, width, height,
@@ -83,6 +94,87 @@ class ThumbnailResource(BaseMediaResource):
)
@defer.inlineCallbacks
+ def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
+ desired_height, desired_method,
+ desired_type):
+ media_info = yield self.store.get_local_media(media_id)
+
+ if not media_info:
+ self._respond_404(request)
+ return
+
+ thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
+ for info in thumbnail_infos:
+ t_w = info["thumbnail_width"] == desired_width
+ t_h = info["thumbnail_height"] == desired_height
+ t_method = info["thumbnail_method"] == desired_method
+ t_type = info["thumbnail_type"] == desired_type
+
+ if t_w and t_h and t_method and t_type:
+ file_path = self.filepaths.local_media_thumbnail(
+ media_id, desired_width, desired_height, desired_type, desired_method,
+ )
+ yield self._respond_with_file(request, desired_type, file_path)
+ return
+
+ logger.debug("We don't have a local thumbnail of that size. Generating")
+
+ # Okay, so we generate one.
+ file_path = yield self._generate_local_exact_thumbnail(
+ media_id, desired_width, desired_height, desired_method, desired_type
+ )
+
+ if file_path:
+ yield self._respond_with_file(request, desired_type, file_path)
+ else:
+ yield self._respond_default_thumbnail(
+ request, media_info, desired_width, desired_height,
+ desired_method, desired_type,
+ )
+
+ @defer.inlineCallbacks
+ def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
+ desired_width, desired_height,
+ desired_method, desired_type):
+ media_info = yield self._get_remote_media(server_name, media_id)
+
+ thumbnail_infos = yield self.store.get_remote_media_thumbnails(
+ server_name, media_id,
+ )
+
+ file_id = media_info["filesystem_id"]
+
+ for info in thumbnail_infos:
+ t_w = info["thumbnail_width"] == desired_width
+ t_h = info["thumbnail_height"] == desired_height
+ t_method = info["thumbnail_method"] == desired_method
+ t_type = info["thumbnail_type"] == desired_type
+
+ if t_w and t_h and t_method and t_type:
+ file_path = self.filepaths.remote_media_thumbnail(
+ server_name, file_id, desired_width, desired_height,
+ desired_type, desired_method,
+ )
+ yield self._respond_with_file(request, desired_type, file_path)
+ return
+
+ logger.debug("We don't have a local thumbnail of that size. Generating")
+
+ # Okay, so we generate one.
+ file_path = yield self._generate_remote_exact_thumbnail(
+ server_name, file_id, media_id, desired_width,
+ desired_height, desired_method, desired_type
+ )
+
+ if file_path:
+ yield self._respond_with_file(request, desired_type, file_path)
+ else:
+ yield self._respond_default_thumbnail(
+ request, media_info, desired_width, desired_height,
+ desired_method, desired_type,
+ )
+
+ @defer.inlineCallbacks
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
height, method, m_type):
# TODO: Don't download the whole remote file
@@ -162,11 +254,12 @@ class ThumbnailResource(BaseMediaResource):
t_method = info["thumbnail_method"]
if t_method == "scale" or t_method == "crop":
aspect_quality = abs(d_w * t_h - d_h * t_w)
+ min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
info_list.append((
- aspect_quality, size_quality, type_quality,
+ aspect_quality, min_quality, size_quality, type_quality,
length_quality, info
))
if info_list:
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 28404f2b7b..1e965c363a 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -82,7 +82,7 @@ class Thumbnailer(object):
def save_image(self, output_image, output_type, output_path):
output_bytes_io = BytesIO()
- output_image.save(output_bytes_io, self.FORMATS[output_type], quality=70)
+ output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
output_bytes = output_bytes_io.getvalue()
with open(output_path, "wb") as output_file:
output_file.write(output_bytes)
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index cc571976a5..031bfa80f8 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -84,6 +84,16 @@ class UploadResource(BaseMediaResource):
code=413,
)
+ upload_name = request.args.get("filename", None)
+ if upload_name:
+ try:
+ upload_name = upload_name[0].decode('UTF-8')
+ except UnicodeDecodeError:
+ raise SynapseError(
+ msg="Invalid UTF-8 filename parameter: %r" % (upload_name),
+ code=400,
+ )
+
headers = request.requestHeaders
if headers.hasHeader("Content-Type"):
@@ -99,7 +109,7 @@ class UploadResource(BaseMediaResource):
# TODO(markjh): parse content-dispostion
content_uri = yield self.create_content(
- media_type, None, request.content.read(),
+ media_type, upload_name, request.content.read(),
content_length, auth_user
)
|