diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d0d4999795..d89b2e5532 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,11 +14,6 @@
# limitations under the License.
import logging
-import xml.etree.ElementTree as ET
-
-from six.moves import urllib
-
-from twisted.web.client import PartialDownloadError
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -28,10 +23,10 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
-from synapse.push.mailer import load_jinja2_templates
+from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
-from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.types import UserID
from synapse.util.msisdn import phone_number_to_msisdn
logger = logging.getLogger(__name__)
@@ -88,6 +83,7 @@ class LoginRestServlet(RestServlet):
self.jwt_algorithm = hs.config.jwt_algorithm
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
+ self.oidc_enabled = hs.config.oidc_enabled
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
self.handlers = hs.get_handlers()
@@ -101,9 +97,7 @@ class LoginRestServlet(RestServlet):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
- if self.saml2_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
- flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+
if self.cas_enabled:
flows.append({"type": LoginRestServlet.SSO_TYPE})
@@ -119,6 +113,11 @@ class LoginRestServlet(RestServlet):
# fall back to the fallback API if they don't understand one of the
# login flow types returned.
flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+ elif self.saml2_enabled:
+ flows.append({"type": LoginRestServlet.SSO_TYPE})
+ flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+ elif self.oidc_enabled:
+ flows.append({"type": LoginRestServlet.SSO_TYPE})
flows.extend(
({"type": t} for t in self.auth_handler.get_supported_login_types())
@@ -402,24 +401,27 @@ class BaseSSORedirectServlet(RestServlet):
PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
- def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest):
args = request.args
if b"redirectUrl" not in args:
return 400, "Redirect URL not specified for SSO auth"
client_redirect_url = args[b"redirectUrl"][0]
- sso_url = self.get_sso_url(client_redirect_url)
+ sso_url = await self.get_sso_url(request, client_redirect_url)
request.redirect(sso_url)
finish_request(request)
- def get_sso_url(self, client_redirect_url):
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
"""Get the URL to redirect to, to perform SSO auth
Args:
- client_redirect_url (bytes): the URL that we should redirect the
+ request: The client request to redirect.
+ client_redirect_url: the URL that we should redirect the
client to when everything is done
Returns:
- bytes: URL to redirect to
+ URL to redirect to
"""
# to be implemented by subclasses
raise NotImplementedError()
@@ -427,19 +429,14 @@ class BaseSSORedirectServlet(RestServlet):
class CasRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
- super(CasRedirectServlet, self).__init__()
- self.cas_server_url = hs.config.cas_server_url.encode("ascii")
- self.cas_service_url = hs.config.cas_service_url.encode("ascii")
+ self._cas_handler = hs.get_cas_handler()
- def get_sso_url(self, client_redirect_url):
- client_redirect_url_param = urllib.parse.urlencode(
- {b"redirectUrl": client_redirect_url}
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
+ return self._cas_handler.get_redirect_url(
+ {"redirectUrl": client_redirect_url}
).encode("ascii")
- hs_redirect_url = self.cas_service_url + b"/_matrix/client/r0/login/cas/ticket"
- service_param = urllib.parse.urlencode(
- {b"service": b"%s?%s" % (hs_redirect_url, client_redirect_url_param)}
- ).encode("ascii")
- return b"%s/login?%s" % (self.cas_server_url, service_param)
class CasTicketServlet(RestServlet):
@@ -447,81 +444,25 @@ class CasTicketServlet(RestServlet):
def __init__(self, hs):
super(CasTicketServlet, self).__init__()
- self.cas_server_url = hs.config.cas_server_url
- self.cas_service_url = hs.config.cas_service_url
- self.cas_displayname_attribute = hs.config.cas_displayname_attribute
- self.cas_required_attributes = hs.config.cas_required_attributes
- self._sso_auth_handler = SSOAuthHandler(hs)
- self._http_client = hs.get_proxied_http_client()
-
- async def on_GET(self, request):
- client_redirect_url = parse_string(request, "redirectUrl", required=True)
- uri = self.cas_server_url + "/proxyValidate"
- args = {
- "ticket": parse_string(request, "ticket", required=True),
- "service": self.cas_service_url,
- }
- try:
- body = await self._http_client.get_raw(uri, args)
- except PartialDownloadError as pde:
- # Twisted raises this error if the connection is closed,
- # even if that's being used old-http style to signal end-of-data
- body = pde.response
- result = await self.handle_cas_response(request, body, client_redirect_url)
- return result
+ self._cas_handler = hs.get_cas_handler()
- def handle_cas_response(self, request, cas_response_body, client_redirect_url):
- user, attributes = self.parse_cas_response(cas_response_body)
- displayname = attributes.pop(self.cas_displayname_attribute, None)
+ async def on_GET(self, request: SynapseRequest) -> None:
+ client_redirect_url = parse_string(request, "redirectUrl")
+ ticket = parse_string(request, "ticket", required=True)
- for required_attribute, required_value in self.cas_required_attributes.items():
- # If required attribute was not in CAS Response - Forbidden
- if required_attribute not in attributes:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+ # Maybe get a session ID (if this ticket is from user interactive
+ # authentication).
+ session = parse_string(request, "session")
- # Also need to check value
- if required_value is not None:
- actual_value = attributes[required_attribute]
- # If required attribute value does not match expected - Forbidden
- if required_value != actual_value:
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+ # Either client_redirect_url or session must be provided.
+ if not client_redirect_url and not session:
+ message = "Missing string query parameter redirectUrl or session"
+ raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
- return self._sso_auth_handler.on_successful_auth(
- user, request, client_redirect_url, displayname
+ await self._cas_handler.handle_ticket(
+ request, ticket, client_redirect_url, session
)
- def parse_cas_response(self, cas_response_body):
- user = None
- attributes = {}
- try:
- root = ET.fromstring(cas_response_body)
- if not root.tag.endswith("serviceResponse"):
- raise Exception("root of CAS response is not serviceResponse")
- success = root[0].tag.endswith("authenticationSuccess")
- for child in root[0]:
- if child.tag.endswith("user"):
- user = child.text
- if child.tag.endswith("attributes"):
- for attribute in child:
- # ElementTree library expands the namespace in
- # attribute tags to the full URL of the namespace.
- # We don't care about namespace here and it will always
- # be encased in curly braces, so we remove them.
- tag = attribute.tag
- if "}" in tag:
- tag = tag.split("}")[1]
- attributes[tag] = attribute.text
- if user is None:
- raise Exception("CAS response does not contain user")
- except Exception:
- logger.exception("Error parsing CAS response")
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
- if not success:
- raise LoginError(
- 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
- )
- return user, attributes
-
class SAMLRedirectServlet(BaseSSORedirectServlet):
PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -529,69 +470,25 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
def __init__(self, hs):
self._saml_handler = hs.get_saml_handler()
- def get_sso_url(self, client_redirect_url):
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
return self._saml_handler.handle_redirect_request(client_redirect_url)
-class SSOAuthHandler(object):
- """
- Utility class for Resources and Servlets which handle the response from a SSO
- service
+class OIDCRedirectServlet(BaseSSORedirectServlet):
+ """Implementation for /login/sso/redirect for the OIDC login flow."""
- Args:
- hs (synapse.server.HomeServer)
- """
+ PATTERNS = client_patterns("/login/sso/redirect", v1=True)
def __init__(self, hs):
- self._hostname = hs.hostname
- self._auth_handler = hs.get_auth_handler()
- self._registration_handler = hs.get_registration_handler()
- self._macaroon_gen = hs.get_macaroon_generator()
-
- # Load the redirect page HTML template
- self._template = load_jinja2_templates(
- hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
- )[0]
-
- self._server_name = hs.config.server_name
-
- # cast to tuple for use with str.startswith
- self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
-
- async def on_successful_auth(
- self, username, request, client_redirect_url, user_display_name=None
- ):
- """Called once the user has successfully authenticated with the SSO.
-
- Registers the user if necessary, and then returns a redirect (with
- a login token) to the client.
-
- Args:
- username (unicode|bytes): the remote user id. We'll map this onto
- something sane for a MXID localpath.
-
- request (SynapseRequest): the incoming request from the browser. We'll
- respond to it with a redirect.
-
- client_redirect_url (unicode): the redirect_url the client gave us when
- it first started the process.
-
- user_display_name (unicode|None): if set, and we have to register a new user,
- we will set their displayname to this.
-
- Returns:
- Deferred[none]: Completes once we have handled the request.
- """
- localpart = map_username_to_mxid_localpart(username)
- user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = await self._auth_handler.check_user_exists(user_id)
- if not registered_user_id:
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart, default_display_name=user_display_name
- )
+ self._oidc_handler = hs.get_oidc_handler()
- self._auth_handler.complete_sso_login(
- registered_user_id, request, client_redirect_url
+ async def get_sso_url(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> bytes:
+ return await self._oidc_handler.handle_redirect_request(
+ request, client_redirect_url
)
@@ -602,3 +499,5 @@ def register_servlets(hs, http_server):
CasTicketServlet(hs).register(http_server)
elif hs.config.saml2_enabled:
SAMLRedirectServlet(hs).register(http_server)
+ elif hs.config.oidc_enabled:
+ OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 1cf3caf832..b0c30b65be 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -34,10 +34,10 @@ class LogoutRestServlet(RestServlet):
return 200, {}
async def on_POST(self, request):
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_expired=True)
if requester.device_id is None:
- # the acccess token wasn't associated with a device.
+ # The access token wasn't associated with a device.
# Just delete the access token
access_token = self.auth.get_access_token_from_request(request)
await self._auth_handler.delete_access_token(access_token)
@@ -62,7 +62,7 @@ class LogoutAllRestServlet(RestServlet):
return 200, {}
async def on_POST(self, request):
- requester = await self.auth.get_user_by_req(request)
+ requester = await self.auth.get_user_by_req(request, allow_expired=True)
user_id = requester.user.to_string()
# first delete all of the user's devices
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index e788eb0193..43b64608e7 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -27,6 +27,7 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
Codes,
+ HttpResponseException,
InvalidClientCredentialsError,
SynapseError,
)
@@ -92,7 +93,7 @@ class RoomCreateRestServlet(TransactionRestServlet):
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request)
- info = await self._room_creation_handler.create_room(
+ info, _ = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@@ -201,7 +202,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
if event_type == EventTypes.Member:
membership = content.get("membership", None)
- event = await self.room_member_handler.update_membership(
+ event_id, _ = await self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),
room_id=room_id,
@@ -209,14 +210,18 @@ class RoomStateEventRestServlet(TransactionRestServlet):
content=content,
)
else:
- event = await self.event_creation_handler.create_and_send_nonmember_event(
+ (
+ event,
+ _,
+ ) = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
+ event_id = event.event_id
ret = {} # type: dict
- if event:
- set_tag("event_id", event.event_id)
- ret = {"event_id": event.event_id}
+ if event_id:
+ set_tag("event_id", event_id)
+ ret = {"event_id": event_id}
return 200, ret
@@ -246,7 +251,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
if b"ts" in request.args and requester.app_service:
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
- event = await self.event_creation_handler.create_and_send_nonmember_event(
+ event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester, event_dict, txn_id=txn_id
)
@@ -364,10 +369,13 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit = None
handler = self.hs.get_room_list_handler()
- if server:
- data = await handler.get_remote_public_room_list(
- server, limit=limit, since_token=since_token
- )
+ if server and server != self.hs.config.server_name:
+ try:
+ data = await handler.get_remote_public_room_list(
+ server, limit=limit, since_token=since_token
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
else:
data = await handler.get_local_public_room_list(
limit=limit, since_token=since_token
@@ -404,15 +412,18 @@ class PublicRoomListRestServlet(TransactionRestServlet):
limit = None
handler = self.hs.get_room_list_handler()
- if server:
- data = await handler.get_remote_public_room_list(
- server,
- limit=limit,
- since_token=since_token,
- search_filter=search_filter,
- include_all_networks=include_all_networks,
- third_party_instance_id=third_party_instance_id,
- )
+ if server and server != self.hs.config.server_name:
+ try:
+ data = await handler.get_remote_public_room_list(
+ server,
+ limit=limit,
+ since_token=since_token,
+ search_filter=search_filter,
+ include_all_networks=include_all_networks,
+ third_party_instance_id=third_party_instance_id,
+ )
+ except HttpResponseException as e:
+ raise e.to_synapse_error()
else:
data = await handler.get_local_public_room_list(
limit=limit,
@@ -775,7 +786,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
- event = await self.event_creation_handler.create_and_send_nonmember_event(
+ event, _ = await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Redaction,
|