diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py
index 41534b8c2a..82433a2aa9 100644
--- a/synapse/rest/client/v1/admin.py
+++ b/synapse/rest/client/v1/admin.py
@@ -23,7 +23,7 @@ from six.moves import http_client
from twisted.internet import defer
-from synapse.api.constants import Membership
+from synapse.api.constants import Membership, UserTypes
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
assert_params_in_dict,
@@ -158,6 +158,11 @@ class UserRegisterServlet(ClientV1RestServlet):
raise SynapseError(400, "Invalid password")
admin = body.get("admin", None)
+ user_type = body.get("user_type", None)
+
+ if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
+ raise SynapseError(400, "Invalid user type")
+
got_mac = body["mac"]
want_mac = hmac.new(
@@ -171,6 +176,9 @@ class UserRegisterServlet(ClientV1RestServlet):
want_mac.update(password)
want_mac.update(b"\x00")
want_mac.update(b"admin" if admin else b"notadmin")
+ if user_type:
+ want_mac.update(b"\x00")
+ want_mac.update(user_type.encode('utf8'))
want_mac = want_mac.hexdigest()
if not hmac.compare_digest(
@@ -189,6 +197,7 @@ class UserRegisterServlet(ClientV1RestServlet):
password=body["password"],
admin=bool(admin),
generate_token=False,
+ user_type=user_type,
)
result = yield register._create_registration_details(user_id, body)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 0010699d31..6121c5b6df 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -18,17 +18,18 @@ import xml.etree.ElementTree as ET
from six.moves import urllib
-from canonicaljson import json
-from saml2 import BINDING_HTTP_POST, config
-from saml2.client import Saml2Client
-
from twisted.internet import defer
from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.http.server import finish_request
-from synapse.http.servlet import parse_json_object_from_request
-from synapse.types import UserID
+from synapse.http.servlet import (
+ RestServlet,
+ parse_json_object_from_request,
+ parse_string,
+)
+from synapse.rest.well_known import WellKnownBuilder
+from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.msisdn import phone_number_to_msisdn
from .base import ClientV1RestServlet, client_path_patterns
@@ -81,30 +82,31 @@ def login_id_thirdparty_from_phone(identifier):
class LoginRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/login$")
- SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
+ SSO_TYPE = "m.login.sso"
TOKEN_TYPE = "m.login.token"
JWT_TYPE = "m.login.jwt"
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
- self.idp_redirect_url = hs.config.saml2_idp_redirect_url
- self.saml2_enabled = hs.config.saml2_enabled
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
self.cas_enabled = hs.config.cas_enabled
self.auth_handler = self.hs.get_auth_handler()
- self.device_handler = self.hs.get_device_handler()
+ self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
+ self._well_known_builder = WellKnownBuilder(hs)
def on_GET(self, request):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
- if self.saml2_enabled:
- flows.append({"type": LoginRestServlet.SAML2_TYPE})
if self.cas_enabled:
+ flows.append({"type": LoginRestServlet.SSO_TYPE})
+
+ # we advertise CAS for backwards compat, though MSC1721 renamed it
+ # to SSO.
flows.append({"type": LoginRestServlet.CAS_TYPE})
# While its valid for us to advertise this login type generally,
@@ -129,29 +131,21 @@ class LoginRestServlet(ClientV1RestServlet):
def on_POST(self, request):
login_submission = parse_json_object_from_request(request)
try:
- if self.saml2_enabled and (login_submission["type"] ==
- LoginRestServlet.SAML2_TYPE):
- relay_state = ""
- if "relay_state" in login_submission:
- relay_state = "&RelayState=" + urllib.parse.quote(
- login_submission["relay_state"])
- result = {
- "uri": "%s%s" % (self.idp_redirect_url, relay_state)
- }
- defer.returnValue((200, result))
- elif self.jwt_enabled and (login_submission["type"] ==
- LoginRestServlet.JWT_TYPE):
+ if self.jwt_enabled and (login_submission["type"] ==
+ LoginRestServlet.JWT_TYPE):
result = yield self.do_jwt_login(login_submission)
- defer.returnValue(result)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
result = yield self.do_token_login(login_submission)
- defer.returnValue(result)
else:
result = yield self._do_other_login(login_submission)
- defer.returnValue(result)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
+ well_known_data = self._well_known_builder.get_well_known()
+ if well_known_data:
+ result["well_known"] = well_known_data
+ defer.returnValue((200, result))
+
@defer.inlineCallbacks
def _do_other_login(self, login_submission):
"""Handle non-token/saml/jwt logins
@@ -160,7 +154,7 @@ class LoginRestServlet(ClientV1RestServlet):
login_submission:
Returns:
- (int, object): HTTP code/response
+ dict: HTTP response
"""
# Log the request we got, but only certain fields to minimise the chance of
# logging someone's password (even if they accidentally put it in the wrong
@@ -226,11 +220,10 @@ class LoginRestServlet(ClientV1RestServlet):
login_submission,
)
- device_id = yield self._register_device(
- canonical_user_id, login_submission,
- )
- access_token = yield auth_handler.get_access_token_for_user_id(
- canonical_user_id, device_id,
+ device_id = login_submission.get("device_id")
+ initial_display_name = login_submission.get("initial_device_display_name")
+ device_id, access_token = yield self.registration_handler.register_device(
+ canonical_user_id, device_id, initial_display_name,
)
result = {
@@ -243,7 +236,7 @@ class LoginRestServlet(ClientV1RestServlet):
if callback is not None:
yield callback(result)
- defer.returnValue((200, result))
+ defer.returnValue(result)
@defer.inlineCallbacks
def do_token_login(self, login_submission):
@@ -252,10 +245,13 @@ class LoginRestServlet(ClientV1RestServlet):
user_id = (
yield auth_handler.validate_short_term_login_token_and_get_user_id(token)
)
- device_id = yield self._register_device(user_id, login_submission)
- access_token = yield auth_handler.get_access_token_for_user_id(
- user_id, device_id,
+
+ device_id = login_submission.get("device_id")
+ initial_display_name = login_submission.get("initial_device_display_name")
+ device_id, access_token = yield self.registration_handler.register_device(
+ user_id, device_id, initial_display_name,
)
+
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
@@ -263,7 +259,7 @@ class LoginRestServlet(ClientV1RestServlet):
"device_id": device_id,
}
- defer.returnValue((200, result))
+ defer.returnValue(result)
@defer.inlineCallbacks
def do_jwt_login(self, login_submission):
@@ -292,11 +288,10 @@ class LoginRestServlet(ClientV1RestServlet):
auth_handler = self.auth_handler
registered_user_id = yield auth_handler.check_user_exists(user_id)
if registered_user_id:
- device_id = yield self._register_device(
- registered_user_id, login_submission
- )
- access_token = yield auth_handler.get_access_token_for_user_id(
- registered_user_id, device_id,
+ device_id = login_submission.get("device_id")
+ initial_display_name = login_submission.get("initial_device_display_name")
+ device_id, access_token = yield self.registration_handler.register_device(
+ registered_user_id, device_id, initial_display_name,
)
result = {
@@ -305,90 +300,30 @@ class LoginRestServlet(ClientV1RestServlet):
"home_server": self.hs.hostname,
}
else:
- # TODO: we should probably check that the register isn't going
- # to fonx/change our user_id before registering the device
- device_id = yield self._register_device(user_id, login_submission)
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
+
+ device_id = login_submission.get("device_id")
+ initial_display_name = login_submission.get("initial_device_display_name")
+ device_id, access_token = yield self.registration_handler.register_device(
+ registered_user_id, device_id, initial_display_name,
+ )
+
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
- defer.returnValue((200, result))
-
- def _register_device(self, user_id, login_submission):
- """Register a device for a user.
-
- This is called after the user's credentials have been validated, but
- before the access token has been issued.
-
- Args:
- (str) user_id: full canonical @user:id
- (object) login_submission: dictionary supplied to /login call, from
- which we pull device_id and initial_device_name
- Returns:
- defer.Deferred: (str) device_id
- """
- device_id = login_submission.get("device_id")
- initial_display_name = login_submission.get(
- "initial_device_display_name")
- return self.device_handler.check_device_registered(
- user_id, device_id, initial_display_name
- )
+ defer.returnValue(result)
-class SAML2RestServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/login/saml2", releases=())
+class CasRedirectServlet(RestServlet):
+ PATTERNS = client_path_patterns("/login/(cas|sso)/redirect")
def __init__(self, hs):
- super(SAML2RestServlet, self).__init__(hs)
- self.sp_config = hs.config.saml2_config_path
- self.handlers = hs.get_handlers()
-
- @defer.inlineCallbacks
- def on_POST(self, request):
- saml2_auth = None
- try:
- conf = config.SPConfig()
- conf.load_file(self.sp_config)
- SP = Saml2Client(conf)
- saml2_auth = SP.parse_authn_request_response(
- request.args['SAMLResponse'][0], BINDING_HTTP_POST)
- except Exception as e: # Not authenticated
- logger.exception(e)
- if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
- username = saml2_auth.name_id.text
- handler = self.handlers.registration_handler
- (user_id, token) = yield handler.register_saml2(username)
- # Forward to the RelayState callback along with ava
- if 'RelayState' in request.args:
- request.redirect(urllib.parse.unquote(
- request.args['RelayState'][0]) +
- '?status=authenticated&access_token=' +
- token + '&user_id=' + user_id + '&ava=' +
- urllib.quote(json.dumps(saml2_auth.ava)))
- finish_request(request)
- defer.returnValue(None)
- defer.returnValue((200, {"status": "authenticated",
- "user_id": user_id, "token": token,
- "ava": saml2_auth.ava}))
- elif 'RelayState' in request.args:
- request.redirect(urllib.parse.unquote(
- request.args['RelayState'][0]) +
- '?status=not_authenticated')
- finish_request(request)
- defer.returnValue(None)
- defer.returnValue((200, {"status": "not_authenticated"}))
-
-
-class CasRedirectServlet(ClientV1RestServlet):
- PATTERNS = client_path_patterns("/login/cas/redirect", releases=())
-
- def __init__(self, hs):
- super(CasRedirectServlet, self).__init__(hs)
+ super(CasRedirectServlet, self).__init__()
self.cas_server_url = hs.config.cas_server_url.encode('ascii')
self.cas_service_url = hs.config.cas_service_url.encode('ascii')
@@ -416,17 +351,15 @@ class CasTicketServlet(ClientV1RestServlet):
self.cas_server_url = hs.config.cas_server_url
self.cas_service_url = hs.config.cas_service_url
self.cas_required_attributes = hs.config.cas_required_attributes
- self.auth_handler = hs.get_auth_handler()
- self.handlers = hs.get_handlers()
- self.macaroon_gen = hs.get_macaroon_generator()
+ self._sso_auth_handler = SSOAuthHandler(hs)
@defer.inlineCallbacks
def on_GET(self, request):
- client_redirect_url = request.args[b"redirectUrl"][0]
+ client_redirect_url = parse_string(request, "redirectUrl", required=True)
http_client = self.hs.get_simple_http_client()
uri = self.cas_server_url + "/proxyValidate"
args = {
- "ticket": request.args[b"ticket"][0].decode('ascii'),
+ "ticket": parse_string(request, "ticket", required=True),
"service": self.cas_service_url
}
try:
@@ -438,7 +371,6 @@ class CasTicketServlet(ClientV1RestServlet):
result = yield self.handle_cas_response(request, body, client_redirect_url)
defer.returnValue(result)
- @defer.inlineCallbacks
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
user, attributes = self.parse_cas_response(cas_response_body)
@@ -454,28 +386,9 @@ class CasTicketServlet(ClientV1RestServlet):
if required_value != actual_value:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
- user_id = UserID(user, self.hs.hostname).to_string()
- auth_handler = self.auth_handler
- registered_user_id = yield auth_handler.check_user_exists(user_id)
- if not registered_user_id:
- registered_user_id, _ = (
- yield self.handlers.registration_handler.register(localpart=user)
- )
-
- login_token = self.macaroon_gen.generate_short_term_login_token(
- registered_user_id
+ return self._sso_auth_handler.on_successful_auth(
+ user, request, client_redirect_url,
)
- redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
- login_token)
- request.redirect(redirect_url)
- finish_request(request)
-
- def add_login_token_to_redirect_url(self, url, token):
- url_parts = list(urllib.parse.urlparse(url))
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
- query.update({"loginToken": token})
- url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
- return urllib.parse.urlunparse(url_parts)
def parse_cas_response(self, cas_response_body):
user = None
@@ -510,10 +423,78 @@ class CasTicketServlet(ClientV1RestServlet):
return user, attributes
+class SSOAuthHandler(object):
+ """
+ Utility class for Resources and Servlets which handle the response from a SSO
+ service
+
+ Args:
+ hs (synapse.server.HomeServer)
+ """
+ def __init__(self, hs):
+ self._hostname = hs.hostname
+ self._auth_handler = hs.get_auth_handler()
+ self._registration_handler = hs.get_registration_handler()
+ self._macaroon_gen = hs.get_macaroon_generator()
+
+ @defer.inlineCallbacks
+ def on_successful_auth(
+ self, username, request, client_redirect_url,
+ user_display_name=None,
+ ):
+ """Called once the user has successfully authenticated with the SSO.
+
+ Registers the user if necessary, and then returns a redirect (with
+ a login token) to the client.
+
+ Args:
+ username (unicode|bytes): the remote user id. We'll map this onto
+ something sane for a MXID localpath.
+
+ request (SynapseRequest): the incoming request from the browser. We'll
+ respond to it with a redirect.
+
+ client_redirect_url (unicode): the redirect_url the client gave us when
+ it first started the process.
+
+ user_display_name (unicode|None): if set, and we have to register a new user,
+ we will set their displayname to this.
+
+ Returns:
+ Deferred[none]: Completes once we have handled the request.
+ """
+ localpart = map_username_to_mxid_localpart(username)
+ user_id = UserID(localpart, self._hostname).to_string()
+ registered_user_id = yield self._auth_handler.check_user_exists(user_id)
+ if not registered_user_id:
+ registered_user_id, _ = (
+ yield self._registration_handler.register(
+ localpart=localpart,
+ generate_token=False,
+ default_display_name=user_display_name,
+ )
+ )
+
+ login_token = self._macaroon_gen.generate_short_term_login_token(
+ registered_user_id
+ )
+ redirect_url = self._add_login_token_to_redirect_url(
+ client_redirect_url, login_token
+ )
+ request.redirect(redirect_url)
+ finish_request(request)
+
+ @staticmethod
+ def _add_login_token_to_redirect_url(url, token):
+ url_parts = list(urllib.parse.urlparse(url))
+ query = dict(urllib.parse.parse_qsl(url_parts[4]))
+ query.update({"loginToken": token})
+ url_parts[4] = urllib.parse.urlencode(query)
+ return urllib.parse.urlunparse(url_parts)
+
+
def register_servlets(hs, http_server):
LoginRestServlet(hs).register(http_server)
- if hs.config.saml2_enabled:
- SAML2RestServlet(hs).register(http_server)
if hs.config.cas_enabled:
CasRedirectServlet(hs).register(http_server)
CasTicketServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py
index 9382b1f124..c654f9b5f0 100644
--- a/synapse/rest/client/v1/push_rule.py
+++ b/synapse/rest/client/v1/push_rule.py
@@ -42,7 +42,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_PUT(self, request):
- spec = _rule_spec_from_path(request.postpath)
+ spec = _rule_spec_from_path([x.decode('utf8') for x in request.postpath])
try:
priority_class = _priority_class_from_spec(spec)
except InvalidRuleException as e:
@@ -103,7 +103,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_DELETE(self, request):
- spec = _rule_spec_from_path(request.postpath)
+ spec = _rule_spec_from_path([x.decode('utf8') for x in request.postpath])
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
@@ -134,7 +134,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
rules = format_push_rules_for_user(requester.user, rules)
- path = request.postpath[1:]
+ path = [x.decode('utf8') for x in request.postpath][1:]
if path == []:
# we're a reference impl: pedantry is our job.
@@ -142,11 +142,10 @@ class PushRuleRestServlet(ClientV1RestServlet):
PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
)
- if path[0] == b'':
+ if path[0] == '':
defer.returnValue((200, rules))
- elif path[0] == b'global':
- path = [x.decode('ascii') for x in path[1:]]
- result = _filter_ruleset_with_path(rules['global'], path)
+ elif path[0] == 'global':
+ result = _filter_ruleset_with_path(rules['global'], path[1:])
defer.returnValue((200, result))
else:
raise UnrecognizedRequestError()
@@ -190,12 +189,24 @@ class PushRuleRestServlet(ClientV1RestServlet):
def _rule_spec_from_path(path):
+ """Turn a sequence of path components into a rule spec
+
+ Args:
+ path (sequence[unicode]): the URL path components.
+
+ Returns:
+ dict: rule spec dict, containing scope/template/rule_id entries,
+ and possibly attr.
+
+ Raises:
+ UnrecognizedRequestError if the path components cannot be parsed.
+ """
if len(path) < 2:
raise UnrecognizedRequestError()
- if path[0] != b'pushrules':
+ if path[0] != 'pushrules':
raise UnrecognizedRequestError()
- scope = path[1].decode('ascii')
+ scope = path[1]
path = path[2:]
if scope != 'global':
raise UnrecognizedRequestError()
@@ -203,13 +214,13 @@ def _rule_spec_from_path(path):
if len(path) == 0:
raise UnrecognizedRequestError()
- template = path[0].decode('ascii')
+ template = path[0]
path = path[1:]
if len(path) == 0 or len(path[0]) == 0:
raise UnrecognizedRequestError()
- rule_id = path[0].decode('ascii')
+ rule_id = path[0]
spec = {
'scope': scope,
@@ -220,7 +231,7 @@ def _rule_spec_from_path(path):
path = path[1:]
if len(path) > 0 and len(path[0]) > 0:
- spec['attr'] = path[0].decode('ascii')
+ spec['attr'] = path[0]
return spec
diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py
index b84f0260f2..4c07ae7f45 100644
--- a/synapse/rest/client/v1/pusher.py
+++ b/synapse/rest/client/v1/pusher.py
@@ -142,7 +142,7 @@ class PushersRemoveRestServlet(RestServlet):
To allow pusher to be delete by clicking a link (ie. GET request)
"""
PATTERNS = client_path_patterns("/pushers/remove$")
- SUCCESS_HTML = "<html><body>You have been unsubscribed</body><html>"
+ SUCCESS_HTML = b"<html><body>You have been unsubscribed</body><html>"
def __init__(self, hs):
super(PushersRemoveRestServlet, self).__init__()
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index fcfe7857f6..48da4d557f 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -89,7 +89,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomStateEventRestServlet, self).__init__(hs)
self.handlers = hs.get_handlers()
- self.event_creation_hander = hs.get_event_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
@@ -172,7 +172,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
content=content,
)
else:
- event = yield self.event_creation_hander.create_and_send_nonmember_event(
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
event_dict,
txn_id=txn_id,
@@ -189,7 +189,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomSendEventRestServlet, self).__init__(hs)
- self.event_creation_hander = hs.get_event_creation_handler()
+ self.event_creation_handler = hs.get_event_creation_handler()
def register(self, http_server):
# /rooms/$roomid/send/$event_type[/$txn_id]
@@ -211,7 +211,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
if b'ts' in request.args and requester.app_service:
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
- event = yield self.event_creation_hander.create_and_send_nonmember_event(
+ event = yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
event_dict,
txn_id=txn_id,
|