diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index f687a68031..2d64ee5e44 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
+import inspect
import logging
import time
import unicodedata
@@ -24,7 +24,6 @@ import attr
import bcrypt # type: ignore[import]
import pymacaroons
-import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
from synapse.api.errors import (
AuthError,
@@ -38,15 +37,16 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
-from synapse.http.server import finish_request
+from synapse.http.server import finish_request, respond_with_html
from synapse.http.servlet import assert_params_in_dict
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
-from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
+from synapse.util import stringutils as stringutils
from synapse.util.msisdn import phone_number_to_msisdn
+from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
@@ -209,18 +209,17 @@ class AuthHandler(BaseHandler):
# after the SSO completes and before redirecting them back to their client.
# It notifies the user they are about to give access to their matrix account
# to the client.
- self._sso_redirect_confirm_template = load_jinja2_templates(
- hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
- )[0]
+ self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template
+
# The following template is shown during user interactive authentication
# in the fallback auth scenario. It notifies the user that they are
# authenticating for an operation to occur on their account.
- self._sso_auth_confirm_template = load_jinja2_templates(
- hs.config.sso_template_dir, ["sso_auth_confirm.html"],
- )[0]
+ self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template
+
# The following template is shown after a successful user interactive
# authentication session. It tells the user they can close the window.
self._sso_auth_success_template = hs.config.sso_auth_success_template
+
# The following template is shown during the SSO authentication process if
# the account is deactivated.
self._sso_account_deactivated_template = (
@@ -239,7 +238,7 @@ class AuthHandler(BaseHandler):
request_body: Dict[str, Any],
clientip: str,
description: str,
- ) -> dict:
+ ) -> Tuple[dict, str]:
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -260,9 +259,14 @@ class AuthHandler(BaseHandler):
describes the operation happening on their account.
Returns:
- The parameters for this request (which may
+ A tuple of (params, session_id).
+
+ 'params' contains the parameters for this request (which may
have been given only in a previous call).
+ 'session_id' is the ID of this session, either passed in by the
+ client or assigned by this call
+
Raises:
InteractiveAuthIncompleteError if the client has not yet completed
any of the permitted login flows
@@ -284,7 +288,7 @@ class AuthHandler(BaseHandler):
flows = [[login_type] for login_type in self._supported_ui_auth_types]
try:
- result, params, _ = await self.check_auth(
+ result, params, session_id = await self.check_ui_auth(
flows, request, request_body, clientip, description
)
except LoginError:
@@ -307,7 +311,7 @@ class AuthHandler(BaseHandler):
if user_id != requester.user.to_string():
raise AuthError(403, "Invalid auth")
- return params
+ return params, session_id
def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
@@ -317,7 +321,7 @@ class AuthHandler(BaseHandler):
"""
return self.checkers.keys()
- async def check_auth(
+ async def check_ui_auth(
self,
flows: List[List[str]],
request: SynapseRequest,
@@ -375,7 +379,7 @@ class AuthHandler(BaseHandler):
# Convert the URI and method to strings.
uri = request.uri.decode("utf-8")
- method = request.uri.decode("utf-8")
+ method = request.method.decode("utf-8")
# If there's no session ID, create a new session.
if not sid:
@@ -438,9 +442,17 @@ class AuthHandler(BaseHandler):
# authentication flow.
await self.store.set_ui_auth_clientdict(sid, clientdict)
+ user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
+ 0
+ ].decode("ascii", "surrogateescape")
+
+ await self.store.add_user_agent_ip_to_ui_auth_session(
+ session.session_id, user_agent, clientip
+ )
+
if not authdict:
raise InteractiveAuthIncompleteError(
- self._auth_dict_for_flows(flows, session.session_id)
+ session.session_id, self._auth_dict_for_flows(flows, session.session_id)
)
# check auth type currently being presented
@@ -487,7 +499,7 @@ class AuthHandler(BaseHandler):
ret = self._auth_dict_for_flows(flows, session.session_id)
ret["completed"] = list(creds)
ret.update(errordict)
- raise InteractiveAuthIncompleteError(ret)
+ raise InteractiveAuthIncompleteError(session.session_id, ret)
async def add_oob_auth(
self, stagetype: str, authdict: Dict[str, Any], clientip: str
@@ -1055,11 +1067,15 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
- await provider.on_logged_out(
+ # This might return an awaitable, if it does block the log out
+ # until it completes.
+ result = provider.on_logged_out(
user_id=str(user_info["user"]),
device_id=user_info["device_id"],
access_token=access_token,
)
+ if inspect.isawaitable(result):
+ await result
# delete pushers associated with this access token
if user_info["token_id"] is not None:
@@ -1120,7 +1136,7 @@ class AuthHandler(BaseHandler):
# for the presence of an email address during password reset was
# case sensitive).
if medium == "email":
- address = address.lower()
+ address = canonicalise_email(address)
await self.store.user_add_threepid(
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
@@ -1148,7 +1164,7 @@ class AuthHandler(BaseHandler):
# 'Canonicalise' email addresses as per above
if medium == "email":
- address = address.lower()
+ address = canonicalise_email(address)
identity_handler = self.hs.get_handlers().identity_handler
result = await identity_handler.try_unbind_threepid(
@@ -1247,13 +1263,8 @@ class AuthHandler(BaseHandler):
)
# Render the HTML and return.
- html_bytes = self._sso_auth_success_template.encode("utf-8")
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
-
- request.write(html_bytes)
- finish_request(request)
+ html = self._sso_auth_success_template
+ respond_with_html(request, 200, html)
async def complete_sso_login(
self,
@@ -1273,13 +1284,7 @@ class AuthHandler(BaseHandler):
# flow.
deactivated = await self.store.get_user_deactivated_status(registered_user_id)
if deactivated:
- html_bytes = self._sso_account_deactivated_template.encode("utf-8")
-
- request.setResponseCode(403)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
- request.write(html_bytes)
- finish_request(request)
+ respond_with_html(request, 403, self._sso_account_deactivated_template)
return
self._complete_sso_login(registered_user_id, request, client_redirect_url)
@@ -1320,17 +1325,12 @@ class AuthHandler(BaseHandler):
# URL we redirect users to.
redirect_url_no_params = client_redirect_url.split("?")[0]
- html_bytes = self._sso_redirect_confirm_template.render(
+ html = self._sso_redirect_confirm_template.render(
display_url=redirect_url_no_params,
redirect_url=redirect_url,
server_name=self._server_name,
- ).encode("utf-8")
-
- request.setResponseCode(200)
- request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
- request.write(html_bytes)
- finish_request(request)
+ )
+ respond_with_html(request, 200, html)
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
|