diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 2c99536678..d0d4999795 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -28,7 +28,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
-from synapse.http.site import SynapseRequest
+from synapse.push.mailer import load_jinja2_templates
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import UserID, map_username_to_mxid_localpart
@@ -548,6 +548,16 @@ class SSOAuthHandler(object):
self._registration_handler = hs.get_registration_handler()
self._macaroon_gen = hs.get_macaroon_generator()
+ # Load the redirect page HTML template
+ self._template = load_jinja2_templates(
+ hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
+ )[0]
+
+ self._server_name = hs.config.server_name
+
+ # cast to tuple for use with str.startswith
+ self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
+
async def on_successful_auth(
self, username, request, client_redirect_url, user_display_name=None
):
@@ -580,36 +590,9 @@ class SSOAuthHandler(object):
localpart=localpart, default_display_name=user_display_name
)
- self.complete_sso_login(registered_user_id, request, client_redirect_url)
-
- def complete_sso_login(
- self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str
- ):
- """Having figured out a mxid for this user, complete the HTTP request
-
- Args:
- registered_user_id:
- request:
- client_redirect_url:
- """
-
- login_token = self._macaroon_gen.generate_short_term_login_token(
- registered_user_id
+ self._auth_handler.complete_sso_login(
+ registered_user_id, request, client_redirect_url
)
- redirect_url = self._add_login_token_to_redirect_url(
- client_redirect_url, login_token
- )
- # Load page
- request.redirect(redirect_url)
- finish_request(request)
-
- @staticmethod
- def _add_login_token_to_redirect_url(url, token):
- url_parts = list(urllib.parse.urlparse(url))
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
- query.update({"loginToken": token})
- url_parts[4] = urllib.parse.urlencode(query)
- return urllib.parse.urlunparse(url_parts)
def register_servlets(hs, http_server):
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 4b6d030a57..ab671f7334 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -18,8 +18,6 @@ from typing import Dict, Set
from canonicaljson import encode_canonical_json, json
from signedjson.sign import sign_json
-from twisted.internet import defer
-
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import (
@@ -125,8 +123,7 @@ class RemoteKey(DirectServeResource):
await self.query_keys(request, query, query_remote_on_cache_miss=True)
- @defer.inlineCallbacks
- def query_keys(self, request, query, query_remote_on_cache_miss=False):
+ async def query_keys(self, request, query, query_remote_on_cache_miss=False):
logger.info("Handling query for keys %r", query)
store_queries = []
@@ -143,7 +140,7 @@ class RemoteKey(DirectServeResource):
for key_id in key_ids:
store_queries.append((server_name, key_id, None))
- cached = yield self.store.get_server_keys_json(store_queries)
+ cached = await self.store.get_server_keys_json(store_queries)
json_results = set()
@@ -215,8 +212,8 @@ class RemoteKey(DirectServeResource):
json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss:
- yield self.fetcher.get_keys(cache_misses)
- yield self.query_keys(request, query, query_remote_on_cache_miss=False)
+ await self.fetcher.get_keys(cache_misses)
+ await self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
signed_keys = []
for key_json in json_results:
diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py
index 69ecc5e4b4..a545c13db7 100644
--- a/synapse/rest/saml2/response_resource.py
+++ b/synapse/rest/saml2/response_resource.py
@@ -14,7 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.http.server import DirectServeResource, wrap_html_request_handler
+from synapse.http.server import (
+ DirectServeResource,
+ finish_request,
+ wrap_html_request_handler,
+)
class SAML2ResponseResource(DirectServeResource):
@@ -24,8 +28,20 @@ class SAML2ResponseResource(DirectServeResource):
def __init__(self, hs):
super().__init__()
+ self._error_html_content = hs.config.saml2_error_html_content
self._saml_handler = hs.get_saml_handler()
+ async def _async_render_GET(self, request):
+ # We're not expecting any GET request on that resource if everything goes right,
+ # but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
+ # In this case, just tell the user that something went wrong and they should
+ # try to authenticate again.
+ request.setResponseCode(400)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%d" % (len(self._error_html_content),))
+ request.write(self._error_html_content.encode("utf8"))
+ finish_request(request)
+
@wrap_html_request_handler
async def _async_render_POST(self, request):
return await self._saml_handler.handle_saml_response(request)
|