diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 51413d910e..3b781d9836 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -126,30 +126,28 @@ class BaseHandler(object):
retry_after_ms=int(1000 * (time_allowed - time_now))
)
- @defer.inlineCallbacks
- def maybe_kick_guest_users(self, event, context=None):
+ async def maybe_kick_guest_users(self, event, context=None):
# Technically this function invalidates current_state by changing it.
# Hopefully this isn't that important to the caller.
if event.type == EventTypes.GuestAccess:
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
- current_state_ids = yield context.get_current_state_ids()
- current_state = yield self.store.get_events(
+ current_state_ids = await context.get_current_state_ids()
+ current_state = await self.store.get_events(
list(current_state_ids.values())
)
else:
- current_state = yield self.state_handler.get_current_state(
+ current_state = await self.state_handler.get_current_state(
event.room_id
)
current_state = list(current_state.values())
logger.info("maybe_kick_guest_users %r", current_state)
- yield self.kick_guest_users(current_state)
+ await self.kick_guest_users(current_state)
- @defer.inlineCallbacks
- def kick_guest_users(self, current_state):
+ async def kick_guest_users(self, current_state):
for member_event in current_state:
try:
if member_event.type != EventTypes.Member:
@@ -180,7 +178,7 @@ class BaseHandler(object):
# homeserver.
requester = synapse.types.create_requester(target_user, is_guest=True)
handler = self.hs.get_room_member_handler()
- yield handler.update_membership(
+ await handler.update_membership(
requester,
target_user,
member_event.room_id,
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 7860f9625e..524281d2f1 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -18,14 +18,12 @@ import logging
import time
import unicodedata
import urllib.parse
-from typing import Any, Dict, Iterable, List, Optional
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import attr
import bcrypt # type: ignore[import]
import pymacaroons
-from twisted.internet import defer
-
import synapse.util.stringutils as stringutils
from synapse.api.constants import LoginType
from synapse.api.errors import (
@@ -43,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.server import finish_request
from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.push.mailer import load_jinja2_templates
from synapse.types import Requester, UserID
-from synapse.util.caches.expiringcache import ExpiringCache
from ._base import BaseHandler
@@ -71,15 +69,6 @@ class AuthHandler(BaseHandler):
self.bcrypt_rounds = hs.config.bcrypt_rounds
- # This is not a cache per se, but a store of all current sessions that
- # expire after N hours
- self.sessions = ExpiringCache(
- cache_name="register_sessions",
- clock=hs.get_clock(),
- expiry_ms=self.SESSION_EXPIRE_MS,
- reset_expiry_on_get=True,
- )
-
account_handler = ModuleApi(hs, self)
self.password_providers = [
module(config=config, account_handler=account_handler)
@@ -91,6 +80,7 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
+ self._sso_enabled = hs.config.saml2_enabled or hs.config.cas_enabled
# we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first
@@ -106,6 +96,13 @@ class AuthHandler(BaseHandler):
if t not in login_types:
login_types.append(t)
self._supported_login_types = login_types
+ # Login types and UI Auth types have a heavy overlap, but are not
+ # necessarily identical. Login types have SSO (and other login types)
+ # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
+ ui_auth_types = login_types.copy()
+ if self._sso_enabled:
+ ui_auth_types.append(LoginType.SSO)
+ self._supported_ui_auth_types = ui_auth_types
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
@@ -113,20 +110,52 @@ class AuthHandler(BaseHandler):
self._clock = self.hs.get_clock()
- # Load the SSO redirect confirmation page HTML template
+ # Expire old UI auth sessions after a period of time.
+ if hs.config.worker_app is None:
+ self._clock.looping_call(
+ run_as_background_process,
+ 5 * 60 * 1000,
+ "expire_old_sessions",
+ self._expire_old_sessions,
+ )
+
+ # Load the SSO HTML templates.
+
+ # The following template is shown to the user during a client login via SSO,
+ # after the SSO completes and before redirecting them back to their client.
+ # It notifies the user they are about to give access to their matrix account
+ # to the client.
self._sso_redirect_confirm_template = load_jinja2_templates(
- hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
+ hs.config.sso_template_dir, ["sso_redirect_confirm.html"],
)[0]
+ # The following template is shown during user interactive authentication
+ # in the fallback auth scenario. It notifies the user that they are
+ # authenticating for an operation to occur on their account.
+ self._sso_auth_confirm_template = load_jinja2_templates(
+ hs.config.sso_template_dir, ["sso_auth_confirm.html"],
+ )[0]
+ # The following template is shown after a successful user interactive
+ # authentication session. It tells the user they can close the window.
+ self._sso_auth_success_template = hs.config.sso_auth_success_template
+ # The following template is shown during the SSO authentication process if
+ # the account is deactivated.
+ self._sso_account_deactivated_template = (
+ hs.config.sso_account_deactivated_template
+ )
self._server_name = hs.config.server_name
# cast to tuple for use with str.startswith
self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
- @defer.inlineCallbacks
- def validate_user_via_ui_auth(
- self, requester: Requester, request_body: Dict[str, Any], clientip: str
- ):
+ async def validate_user_via_ui_auth(
+ self,
+ requester: Requester,
+ request: SynapseRequest,
+ request_body: Dict[str, Any],
+ clientip: str,
+ description: str,
+ ) -> dict:
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -137,12 +166,17 @@ class AuthHandler(BaseHandler):
Args:
requester: The user, as given by the access token
+ request: The request sent by the client.
+
request_body: The body of the request sent by the client
clientip: The IP address of the client.
+ description: A human readable string to be displayed to the user that
+ describes the operation happening on their account.
+
Returns:
- defer.Deferred[dict]: the parameters for this request (which may
+ The parameters for this request (which may
have been given only in a previous call).
Raises:
@@ -169,10 +203,12 @@ class AuthHandler(BaseHandler):
)
# build a list of supported flows
- flows = [[login_type] for login_type in self._supported_login_types]
+ flows = [[login_type] for login_type in self._supported_ui_auth_types]
try:
- result, params, _ = yield self.check_auth(flows, request_body, clientip)
+ result, params, _ = await self.check_auth(
+ flows, request, request_body, clientip, description
+ )
except LoginError:
# Update the ratelimite to say we failed (`can_do_action` doesn't raise).
self._failed_uia_attempts_ratelimiter.can_do_action(
@@ -185,7 +221,7 @@ class AuthHandler(BaseHandler):
raise
# find the completed login type
- for login_type in self._supported_login_types:
+ for login_type in self._supported_ui_auth_types:
if login_type not in result:
continue
@@ -209,18 +245,18 @@ class AuthHandler(BaseHandler):
"""
return self.checkers.keys()
- @defer.inlineCallbacks
- def check_auth(
- self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
- ):
+ async def check_auth(
+ self,
+ flows: List[List[str]],
+ request: SynapseRequest,
+ clientdict: Dict[str, Any],
+ clientip: str,
+ description: str,
+ ) -> Tuple[dict, dict, str]:
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the User-Interactive Auth flow.
- As a side effect, this function fills in the 'creds' key on the user's
- session with a map, which maps each auth-type (str) to the relevant
- identity authenticated by that auth-type (mostly str, but for captcha, bool).
-
If no auth flows have been completed successfully, raises an
InteractiveAuthIncompleteError. To handle this, you can use
synapse.rest.client.v2_alpha._base.interactive_auth_handler as a
@@ -231,14 +267,18 @@ class AuthHandler(BaseHandler):
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
+ request: The request sent by the client.
+
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
clientip: The IP address of the client.
+ description: A human readable string to be displayed to the user that
+ describes the operation happening on their account.
+
Returns:
- defer.Deferred[dict, dict, str]: a deferred tuple of
- (creds, params, session_id).
+ A tuple of (creds, params, session_id).
'creds' contains the authenticated credentials of each stage.
@@ -260,9 +300,26 @@ class AuthHandler(BaseHandler):
del clientdict["auth"]
if "session" in authdict:
sid = authdict["session"]
- session = self._get_session_info(sid)
- if len(clientdict) > 0:
+ # Convert the URI and method to strings.
+ uri = request.uri.decode("utf-8")
+ method = request.uri.decode("utf-8")
+
+ # If there's no session ID, create a new session.
+ if not sid:
+ session = await self.store.create_ui_auth_session(
+ clientdict, uri, method, description
+ )
+
+ else:
+ try:
+ session = await self.store.get_ui_auth_session(sid)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (sid,))
+
+ # If the client provides parameters, update what is persisted,
+ # otherwise use whatever was last provided.
+ #
# This was designed to allow the client to omit the parameters
# and just supply the session in subsequent calls so it split
# auth between devices by just sharing the session, (eg. so you
@@ -270,31 +327,60 @@ class AuthHandler(BaseHandler):
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
# on a homeserver.
- # Revisit: Assumimg the REST APIs do sensible validation, the data
- # isn't arbintrary.
- session["clientdict"] = clientdict
- self._save_session(session)
- elif "clientdict" in session:
- clientdict = session["clientdict"]
+ #
+ # Revisit: Assuming the REST APIs do sensible validation, the data
+ # isn't arbitrary.
+ #
+ # Note that the registration endpoint explicitly removes the
+ # "initial_device_display_name" parameter if it is provided
+ # without a "password" parameter. See the changes to
+ # synapse.rest.client.v2_alpha.register.RegisterRestServlet.on_POST
+ # in commit 544722bad23fc31056b9240189c3cbbbf0ffd3f9.
+ if not clientdict:
+ clientdict = session.clientdict
+
+ # Ensure that the queried operation does not vary between stages of
+ # the UI authentication session. This is done by generating a stable
+ # comparator and storing it during the initial query. Subsequent
+ # queries ensure that this comparator has not changed.
+ #
+ # The comparator is based on the requested URI and HTTP method. The
+ # client dict (minus the auth dict) should also be checked, but some
+ # clients are not spec compliant, just warn for now if the client
+ # dict changes.
+ if (session.uri, session.method) != (uri, method):
+ raise SynapseError(
+ 403,
+ "Requested operation has changed during the UI authentication session.",
+ )
+
+ if session.clientdict != clientdict:
+ logger.warning(
+ "Requested operation has changed during the UI "
+ "authentication session. A future version of Synapse "
+ "will remove this capability."
+ )
+
+ # For backwards compatibility, changes to the client dict are
+ # persisted as clients modify them throughout their user interactive
+ # authentication flow.
+ await self.store.set_ui_auth_clientdict(sid, clientdict)
if not authdict:
raise InteractiveAuthIncompleteError(
- self._auth_dict_for_flows(flows, session)
+ self._auth_dict_for_flows(flows, session.session_id)
)
- if "creds" not in session:
- session["creds"] = {}
- creds = session["creds"]
-
# check auth type currently being presented
errordict = {} # type: Dict[str, Any]
if "type" in authdict:
login_type = authdict["type"] # type: str
try:
- result = yield self._check_auth_dict(authdict, clientip)
+ result = await self._check_auth_dict(authdict, clientip)
if result:
- creds[login_type] = result
- self._save_session(session)
+ await self.store.mark_ui_auth_stage_complete(
+ session.session_id, login_type, result
+ )
except LoginError as e:
if login_type == LoginType.EMAIL_IDENTITY:
# riot used to have a bug where it would request a new
@@ -310,6 +396,7 @@ class AuthHandler(BaseHandler):
# so that the client can have another go.
errordict = e.error_dict()
+ creds = await self.store.get_completed_ui_auth_stages(session.session_id)
for f in flows:
if len(set(f) - set(creds)) == 0:
# it's very useful to know what args are stored, but this can
@@ -322,15 +409,17 @@ class AuthHandler(BaseHandler):
creds,
list(clientdict),
)
- return creds, clientdict, session["id"]
- ret = self._auth_dict_for_flows(flows, session)
+ return creds, clientdict, session.session_id
+
+ ret = self._auth_dict_for_flows(flows, session.session_id)
ret["completed"] = list(creds)
ret.update(errordict)
raise InteractiveAuthIncompleteError(ret)
- @defer.inlineCallbacks
- def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str):
+ async def add_oob_auth(
+ self, stagetype: str, authdict: Dict[str, Any], clientip: str
+ ) -> bool:
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
@@ -340,15 +429,11 @@ class AuthHandler(BaseHandler):
if "session" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
- sess = self._get_session_info(authdict["session"])
- if "creds" not in sess:
- sess["creds"] = {}
- creds = sess["creds"]
-
- result = yield self.checkers[stagetype].check_auth(authdict, clientip)
+ result = await self.checkers[stagetype].check_auth(authdict, clientip)
if result:
- creds[stagetype] = result
- self._save_session(sess)
+ await self.store.mark_ui_auth_stage_complete(
+ authdict["session"], stagetype, result
+ )
return True
return False
@@ -370,7 +455,7 @@ class AuthHandler(BaseHandler):
sid = authdict["session"]
return sid
- def set_session_data(self, session_id: str, key: str, value: Any) -> None:
+ async def set_session_data(self, session_id: str, key: str, value: Any) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
@@ -381,11 +466,12 @@ class AuthHandler(BaseHandler):
key: The key to store the data under
value: The data to store
"""
- sess = self._get_session_info(session_id)
- sess.setdefault("serverdict", {})[key] = value
- self._save_session(sess)
+ try:
+ await self.store.set_ui_auth_session_data(session_id, key, value)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
- def get_session_data(
+ async def get_session_data(
self, session_id: str, key: str, default: Optional[Any] = None
) -> Any:
"""
@@ -396,11 +482,22 @@ class AuthHandler(BaseHandler):
key: The key to store the data under
default: Value to return if the key has not been set
"""
- sess = self._get_session_info(session_id)
- return sess.setdefault("serverdict", {}).get(key, default)
+ try:
+ return await self.store.get_ui_auth_session_data(session_id, key, default)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
- @defer.inlineCallbacks
- def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str):
+ async def _expire_old_sessions(self):
+ """
+ Invalidate any user interactive authentication sessions that have expired.
+ """
+ now = self._clock.time_msec()
+ expiration_time = now - self.SESSION_EXPIRE_MS
+ await self.store.delete_old_ui_auth_sessions(expiration_time)
+
+ async def _check_auth_dict(
+ self, authdict: Dict[str, Any], clientip: str
+ ) -> Union[Dict[str, Any], str]:
"""Attempt to validate the auth dict provided by a client
Args:
@@ -408,7 +505,7 @@ class AuthHandler(BaseHandler):
clientip: IP address of the client
Returns:
- Deferred: result of the stage verification.
+ Result of the stage verification.
Raises:
StoreError if there was a problem accessing the database
@@ -418,7 +515,7 @@ class AuthHandler(BaseHandler):
login_type = authdict["type"]
checker = self.checkers.get(login_type)
if checker is not None:
- res = yield checker.check_auth(authdict, clientip=clientip)
+ res = await checker.check_auth(authdict, clientip=clientip)
return res
# build a v1-login-style dict out of the authdict and fall back to the
@@ -428,7 +525,7 @@ class AuthHandler(BaseHandler):
if user_id is None:
raise SynapseError(400, "", Codes.MISSING_PARAM)
- (canonical_id, callback) = yield self.validate_login(user_id, authdict)
+ (canonical_id, callback) = await self.validate_login(user_id, authdict)
return canonical_id
def _get_params_recaptcha(self) -> dict:
@@ -452,7 +549,7 @@ class AuthHandler(BaseHandler):
}
def _auth_dict_for_flows(
- self, flows: List[List[str]], session: Dict[str, Any]
+ self, flows: List[List[str]], session_id: str,
) -> Dict[str, Any]:
public_flows = []
for f in flows:
@@ -471,31 +568,12 @@ class AuthHandler(BaseHandler):
params[stage] = get_params[stage]()
return {
- "session": session["id"],
+ "session": session_id,
"flows": [{"stages": f} for f in public_flows],
"params": params,
}
- def _get_session_info(self, session_id: Optional[str]) -> dict:
- """
- Gets or creates a session given a session ID.
-
- The session can be used to track data across multiple requests, e.g. for
- interactive authentication.
- """
- if session_id not in self.sessions:
- session_id = None
-
- if not session_id:
- # create a new session
- while session_id is None or session_id in self.sessions:
- session_id = stringutils.random_string(24)
- self.sessions[session_id] = {"id": session_id}
-
- return self.sessions[session_id]
-
- @defer.inlineCallbacks
- def get_access_token_for_user_id(
+ async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
):
"""
@@ -525,10 +603,10 @@ class AuthHandler(BaseHandler):
)
logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
- yield self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
- yield self.store.add_access_token_to_user(
+ await self.store.add_access_token_to_user(
user_id, access_token, device_id, valid_until_ms
)
@@ -538,15 +616,14 @@ class AuthHandler(BaseHandler):
# device, so we double-check it here.
if device_id is not None:
try:
- yield self.store.get_device(user_id, device_id)
+ await self.store.get_device(user_id, device_id)
except StoreError:
- yield self.store.delete_access_token(access_token)
+ await self.store.delete_access_token(access_token)
raise StoreError(400, "Login raced against device deletion")
return access_token
- @defer.inlineCallbacks
- def check_user_exists(self, user_id: str):
+ async def check_user_exists(self, user_id: str) -> Optional[str]:
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
@@ -555,28 +632,25 @@ class AuthHandler(BaseHandler):
user_id: complete @user:id
Returns:
- defer.Deferred: (unicode) canonical_user_id, or None if zero or
- multiple matches
-
- Raises:
- UserDeactivatedError if a user is found but is deactivated.
+ The canonical_user_id, or None if zero or multiple matches
"""
- res = yield self._find_user_id_and_pwd_hash(user_id)
+ res = await self._find_user_id_and_pwd_hash(user_id)
if res is not None:
return res[0]
return None
- @defer.inlineCallbacks
- def _find_user_id_and_pwd_hash(self, user_id: str):
+ async def _find_user_id_and_pwd_hash(
+ self, user_id: str
+ ) -> Optional[Tuple[str, str]]:
"""Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact
matches.
Returns:
- tuple: A 2-tuple of `(canonical_user_id, password_hash)`
- None: if there is not exactly one match
+ A 2-tuple of `(canonical_user_id, password_hash)` or `None`
+ if there is not exactly one match
"""
- user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
+ user_infos = await self.store.get_users_by_id_case_insensitive(user_id)
result = None
if not user_infos:
@@ -609,8 +683,9 @@ class AuthHandler(BaseHandler):
"""
return self._supported_login_types
- @defer.inlineCallbacks
- def validate_login(self, username: str, login_submission: Dict[str, Any]):
+ async def validate_login(
+ self, username: str, login_submission: Dict[str, Any]
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
@@ -621,7 +696,7 @@ class AuthHandler(BaseHandler):
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
Returns:
- Deferred[str, func]: canonical user id, and optional callback
+ A tuple of the canonical user id, and optional callback
to be called once the access token and device id are issued
Raises:
StoreError if there was a problem accessing the database
@@ -650,7 +725,7 @@ class AuthHandler(BaseHandler):
for provider in self.password_providers:
if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
known_login_type = True
- is_valid = yield provider.check_password(qualified_user_id, password)
+ is_valid = await provider.check_password(qualified_user_id, password)
if is_valid:
return qualified_user_id, None
@@ -682,7 +757,7 @@ class AuthHandler(BaseHandler):
% (login_type, missing_fields),
)
- result = yield provider.check_auth(username, login_type, login_dict)
+ result = await provider.check_auth(username, login_type, login_dict)
if result:
if isinstance(result, str):
result = (result, None)
@@ -691,8 +766,8 @@ class AuthHandler(BaseHandler):
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True
- canonical_user_id = yield self._check_local_password(
- qualified_user_id, password
+ canonical_user_id = await self._check_local_password(
+ qualified_user_id, password # type: ignore
)
if canonical_user_id:
@@ -705,8 +780,9 @@ class AuthHandler(BaseHandler):
# login, it turns all LoginErrors into a 401 anyway.
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
- @defer.inlineCallbacks
- def check_password_provider_3pid(self, medium: str, address: str, password: str):
+ async def check_password_provider_3pid(
+ self, medium: str, address: str, password: str
+ ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@@ -715,9 +791,8 @@ class AuthHandler(BaseHandler):
password: The password of the user.
Returns:
- Deferred[(str|None, func|None)]: A tuple of `(user_id,
- callback)`. If authentication is successful, `user_id` is a `str`
- containing the authenticated, canonical user ID. `callback` is
+ A tuple of `(user_id, callback)`. If authentication is successful,
+ `user_id`is the authenticated, canonical user ID. `callback` is
then either a function to be later run after the server has
completed login/registration, or `None`. If authentication was
unsuccessful, `user_id` and `callback` are both `None`.
@@ -729,7 +804,7 @@ class AuthHandler(BaseHandler):
# success, to a str (which is the user_id) or a tuple of
# (user_id, callback_func), where callback_func should be run
# after we've finished everything else
- result = yield provider.check_3pid_auth(medium, address, password)
+ result = await provider.check_3pid_auth(medium, address, password)
if result:
# Check if the return value is a str or a tuple
if isinstance(result, str):
@@ -739,8 +814,7 @@ class AuthHandler(BaseHandler):
return None, None
- @defer.inlineCallbacks
- def _check_local_password(self, user_id: str, password: str):
+ async def _check_local_password(self, user_id: str, password: str) -> Optional[str]:
"""Authenticate a user against the local password database.
user_id is checked case insensitively, but will return None if there are
@@ -750,28 +824,26 @@ class AuthHandler(BaseHandler):
user_id: complete @user:id
password: the provided password
Returns:
- Deferred[unicode] the canonical_user_id, or Deferred[None] if
- unknown user/bad password
+ The canonical_user_id, or None if unknown user/bad password
"""
- lookupres = yield self._find_user_id_and_pwd_hash(user_id)
+ lookupres = await self._find_user_id_and_pwd_hash(user_id)
if not lookupres:
return None
(user_id, password_hash) = lookupres
# If the password hash is None, the account has likely been deactivated
if not password_hash:
- deactivated = yield self.store.get_user_deactivated_status(user_id)
+ deactivated = await self.store.get_user_deactivated_status(user_id)
if deactivated:
raise UserDeactivatedError("This account has been deactivated")
- result = yield self.validate_hash(password, password_hash)
+ result = await self.validate_hash(password, password_hash)
if not result:
logger.warning("Failed password login for user %s", user_id)
return None
return user_id
- @defer.inlineCallbacks
- def validate_short_term_login_token_and_get_user_id(self, login_token: str):
+ async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
auth_api = self.hs.get_auth()
user_id = None
try:
@@ -781,26 +853,23 @@ class AuthHandler(BaseHandler):
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
- yield self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(user_id)
return user_id
- @defer.inlineCallbacks
- def delete_access_token(self, access_token: str):
+ async def delete_access_token(self, access_token: str):
"""Invalidate a single access token
Args:
access_token: access token to be deleted
- Returns:
- Deferred
"""
- user_info = yield self.auth.get_user_by_access_token(access_token)
- yield self.store.delete_access_token(access_token)
+ user_info = await self.auth.get_user_by_access_token(access_token)
+ await self.store.delete_access_token(access_token)
# see if any of our auth providers want to know about this
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
- yield provider.on_logged_out(
+ await provider.on_logged_out(
user_id=str(user_info["user"]),
device_id=user_info["device_id"],
access_token=access_token,
@@ -808,12 +877,11 @@ class AuthHandler(BaseHandler):
# delete pushers associated with this access token
if user_info["token_id"] is not None:
- yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+ await self.hs.get_pusherpool().remove_pushers_by_access_token(
str(user_info["user"]), (user_info["token_id"],)
)
- @defer.inlineCallbacks
- def delete_access_tokens_for_user(
+ async def delete_access_tokens_for_user(
self,
user_id: str,
except_token_id: Optional[str] = None,
@@ -827,10 +895,8 @@ class AuthHandler(BaseHandler):
device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
- Returns:
- Deferred
"""
- tokens_and_devices = yield self.store.user_delete_access_tokens(
+ tokens_and_devices = await self.store.user_delete_access_tokens(
user_id, except_token_id=except_token_id, device_id=device_id
)
@@ -838,17 +904,18 @@ class AuthHandler(BaseHandler):
for provider in self.password_providers:
if hasattr(provider, "on_logged_out"):
for token, token_id, device_id in tokens_and_devices:
- yield provider.on_logged_out(
+ await provider.on_logged_out(
user_id=user_id, device_id=device_id, access_token=token
)
# delete pushers associated with the access tokens
- yield self.hs.get_pusherpool().remove_pushers_by_access_token(
+ await self.hs.get_pusherpool().remove_pushers_by_access_token(
user_id, (token_id for _, token_id, _ in tokens_and_devices)
)
- @defer.inlineCallbacks
- def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int):
+ async def add_threepid(
+ self, user_id: str, medium: str, address: str, validated_at: int
+ ):
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@@ -869,14 +936,13 @@ class AuthHandler(BaseHandler):
if medium == "email":
address = address.lower()
- yield self.store.user_add_threepid(
+ await self.store.user_add_threepid(
user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
)
- @defer.inlineCallbacks
- def delete_threepid(
+ async def delete_threepid(
self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
- ):
+ ) -> bool:
"""Attempts to unbind the 3pid on the identity servers and deletes it
from the local database.
@@ -889,7 +955,7 @@ class AuthHandler(BaseHandler):
identity server specified when binding (if known).
Returns:
- Deferred[bool]: Returns True if successfully unbound the 3pid on
+ Returns True if successfully unbound the 3pid on
the identity server, False if identity server doesn't support the
unbind API.
"""
@@ -899,28 +965,21 @@ class AuthHandler(BaseHandler):
address = address.lower()
identity_handler = self.hs.get_handlers().identity_handler
- result = yield identity_handler.try_unbind_threepid(
+ result = await identity_handler.try_unbind_threepid(
user_id, {"medium": medium, "address": address, "id_server": id_server}
)
- yield self.store.user_delete_threepid(user_id, medium, address)
+ await self.store.user_delete_threepid(user_id, medium, address)
return result
- def _save_session(self, session: Dict[str, Any]) -> None:
- """Update the last used time on the session to now and add it back to the session store."""
- # TODO: Persistent storage
- logger.debug("Saving session %s", session)
- session["last_used"] = self.hs.get_clock().time_msec()
- self.sessions[session["id"]] = session
-
- def hash(self, password: str):
+ async def hash(self, password: str) -> str:
"""Computes a secure hash of password.
Args:
password: Password to hash.
Returns:
- Deferred(unicode): Hashed password.
+ Hashed password.
"""
def _do_hash():
@@ -932,9 +991,11 @@ class AuthHandler(BaseHandler):
bcrypt.gensalt(self.bcrypt_rounds),
).decode("ascii")
- return defer_to_thread(self.hs.get_reactor(), _do_hash)
+ return await defer_to_thread(self.hs.get_reactor(), _do_hash)
- def validate_hash(self, password: str, stored_hash: bytes):
+ async def validate_hash(
+ self, password: str, stored_hash: Union[bytes, str]
+ ) -> bool:
"""Validates that self.hash(password) == stored_hash.
Args:
@@ -942,7 +1003,7 @@ class AuthHandler(BaseHandler):
stored_hash: Expected hash value.
Returns:
- Deferred(bool): Whether self.hash(password) == stored_hash.
+ Whether self.hash(password) == stored_hash.
"""
def _do_validate_hash():
@@ -958,11 +1019,57 @@ class AuthHandler(BaseHandler):
if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode("ascii")
- return defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
+ return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
else:
- return defer.succeed(False)
+ return False
+
+ async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str:
+ """
+ Get the HTML for the SSO redirect confirmation page.
+
+ Args:
+ redirect_url: The URL to redirect to the SSO provider.
+ session_id: The user interactive authentication session ID.
+
+ Returns:
+ The HTML to render.
+ """
+ try:
+ session = await self.store.get_ui_auth_session(session_id)
+ except StoreError:
+ raise SynapseError(400, "Unknown session ID: %s" % (session_id,))
+ return self._sso_auth_confirm_template.render(
+ description=session.description, redirect_url=redirect_url,
+ )
+
+ async def complete_sso_ui_auth(
+ self, registered_user_id: str, session_id: str, request: SynapseRequest,
+ ):
+ """Having figured out a mxid for this user, complete the HTTP request
+
+ Args:
+ registered_user_id: The registered user ID to complete SSO login for.
+ request: The request to complete.
+ client_redirect_url: The URL to which to redirect the user at the end of the
+ process.
+ """
+ # Mark the stage of the authentication as successful.
+ # Save the user who authenticated with SSO, this will be used to ensure
+ # that the account be modified is also the person who logged in.
+ await self.store.mark_ui_auth_stage_complete(
+ session_id, LoginType.SSO, registered_user_id
+ )
- def complete_sso_login(
+ # Render the HTML and return.
+ html_bytes = self._sso_auth_success_template.encode("utf-8")
+ request.setResponseCode(200)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+
+ request.write(html_bytes)
+ finish_request(request)
+
+ async def complete_sso_login(
self,
registered_user_id: str,
request: SynapseRequest,
@@ -976,6 +1083,32 @@ class AuthHandler(BaseHandler):
client_redirect_url: The URL to which to redirect the user at the end of the
process.
"""
+ # If the account has been deactivated, do not proceed with the login
+ # flow.
+ deactivated = await self.store.get_user_deactivated_status(registered_user_id)
+ if deactivated:
+ html_bytes = self._sso_account_deactivated_template.encode("utf-8")
+
+ request.setResponseCode(403)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+ request.write(html_bytes)
+ finish_request(request)
+ return
+
+ self._complete_sso_login(registered_user_id, request, client_redirect_url)
+
+ def _complete_sso_login(
+ self,
+ registered_user_id: str,
+ request: SynapseRequest,
+ client_redirect_url: str,
+ ):
+ """
+ The synchronous portion of complete_sso_login.
+
+ This exists purely for backwards compatibility of synapse.module_api.ModuleApi.
+ """
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
@@ -1001,7 +1134,7 @@ class AuthHandler(BaseHandler):
# URL we redirect users to.
redirect_url_no_params = client_redirect_url.split("?")[0]
- html = self._sso_redirect_confirm_template.render(
+ html_bytes = self._sso_redirect_confirm_template.render(
display_url=redirect_url_no_params,
redirect_url=redirect_url,
server_name=self._server_name,
@@ -1009,8 +1142,8 @@ class AuthHandler(BaseHandler):
request.setResponseCode(200)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
- request.setHeader(b"Content-Length", b"%d" % (len(html),))
- request.write(html)
+ request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
+ request.write(html_bytes)
finish_request(request)
@staticmethod
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
new file mode 100644
index 0000000000..64aaa1335c
--- /dev/null
+++ b/synapse/handlers/cas_handler.py
@@ -0,0 +1,221 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import xml.etree.ElementTree as ET
+from typing import Dict, Optional, Tuple
+
+from six.moves import urllib
+
+from twisted.web.client import PartialDownloadError
+
+from synapse.api.errors import Codes, LoginError
+from synapse.http.site import SynapseRequest
+from synapse.types import UserID, map_username_to_mxid_localpart
+
+logger = logging.getLogger(__name__)
+
+
+class CasHandler:
+ """
+ Utility class for to handle the response from a CAS SSO service.
+
+ Args:
+ hs (synapse.server.HomeServer)
+ """
+
+ def __init__(self, hs):
+ self._hostname = hs.hostname
+ self._auth_handler = hs.get_auth_handler()
+ self._registration_handler = hs.get_registration_handler()
+
+ self._cas_server_url = hs.config.cas_server_url
+ self._cas_service_url = hs.config.cas_service_url
+ self._cas_displayname_attribute = hs.config.cas_displayname_attribute
+ self._cas_required_attributes = hs.config.cas_required_attributes
+
+ self._http_client = hs.get_proxied_http_client()
+
+ def _build_service_param(self, args: Dict[str, str]) -> str:
+ """
+ Generates a value to use as the "service" parameter when redirecting or
+ querying the CAS service.
+
+ Args:
+ args: Additional arguments to include in the final redirect URL.
+
+ Returns:
+ The URL to use as a "service" parameter.
+ """
+ return "%s%s?%s" % (
+ self._cas_service_url,
+ "/_matrix/client/r0/login/cas/ticket",
+ urllib.parse.urlencode(args),
+ )
+
+ async def _validate_ticket(
+ self, ticket: str, service_args: Dict[str, str]
+ ) -> Tuple[str, Optional[str]]:
+ """
+ Validate a CAS ticket with the server, parse the response, and return the user and display name.
+
+ Args:
+ ticket: The CAS ticket from the client.
+ service_args: Additional arguments to include in the service URL.
+ Should be the same as those passed to `get_redirect_url`.
+ """
+ uri = self._cas_server_url + "/proxyValidate"
+ args = {
+ "ticket": ticket,
+ "service": self._build_service_param(service_args),
+ }
+ try:
+ body = await self._http_client.get_raw(uri, args)
+ except PartialDownloadError as pde:
+ # Twisted raises this error if the connection is closed,
+ # even if that's being used old-http style to signal end-of-data
+ body = pde.response
+
+ user, attributes = self._parse_cas_response(body)
+ displayname = attributes.pop(self._cas_displayname_attribute, None)
+
+ for required_attribute, required_value in self._cas_required_attributes.items():
+ # If required attribute was not in CAS Response - Forbidden
+ if required_attribute not in attributes:
+ raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
+ # Also need to check value
+ if required_value is not None:
+ actual_value = attributes[required_attribute]
+ # If required attribute value does not match expected - Forbidden
+ if required_value != actual_value:
+ raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
+
+ return user, displayname
+
+ def _parse_cas_response(
+ self, cas_response_body: str
+ ) -> Tuple[str, Dict[str, Optional[str]]]:
+ """
+ Retrieve the user and other parameters from the CAS response.
+
+ Args:
+ cas_response_body: The response from the CAS query.
+
+ Returns:
+ A tuple of the user and a mapping of other attributes.
+ """
+ user = None
+ attributes = {}
+ try:
+ root = ET.fromstring(cas_response_body)
+ if not root.tag.endswith("serviceResponse"):
+ raise Exception("root of CAS response is not serviceResponse")
+ success = root[0].tag.endswith("authenticationSuccess")
+ for child in root[0]:
+ if child.tag.endswith("user"):
+ user = child.text
+ if child.tag.endswith("attributes"):
+ for attribute in child:
+ # ElementTree library expands the namespace in
+ # attribute tags to the full URL of the namespace.
+ # We don't care about namespace here and it will always
+ # be encased in curly braces, so we remove them.
+ tag = attribute.tag
+ if "}" in tag:
+ tag = tag.split("}")[1]
+ attributes[tag] = attribute.text
+ if user is None:
+ raise Exception("CAS response does not contain user")
+ except Exception:
+ logger.exception("Error parsing CAS response")
+ raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
+ if not success:
+ raise LoginError(
+ 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
+ )
+ return user, attributes
+
+ def get_redirect_url(self, service_args: Dict[str, str]) -> str:
+ """
+ Generates a URL for the CAS server where the client should be redirected.
+
+ Args:
+ service_args: Additional arguments to include in the final redirect URL.
+
+ Returns:
+ The URL to redirect the client to.
+ """
+ args = urllib.parse.urlencode(
+ {"service": self._build_service_param(service_args)}
+ )
+
+ return "%s/login?%s" % (self._cas_server_url, args)
+
+ async def handle_ticket(
+ self,
+ request: SynapseRequest,
+ ticket: str,
+ client_redirect_url: Optional[str],
+ session: Optional[str],
+ ) -> None:
+ """
+ Called once the user has successfully authenticated with the SSO.
+ Validates a CAS ticket sent by the client and completes the auth process.
+
+ If the user interactive authentication session is provided, marks the
+ UI Auth session as complete, then returns an HTML page notifying the
+ user they are done.
+
+ Otherwise, this registers the user if necessary, and then returns a
+ redirect (with a login token) to the client.
+
+ Args:
+ request: the incoming request from the browser. We'll
+ respond to it with a redirect or an HTML page.
+
+ ticket: The CAS ticket provided by the client.
+
+ client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
+ This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
+
+ session: The session parameter from the `/cas/ticket` HTTP request, if given.
+ This should be the UI Auth session id.
+ """
+ args = {}
+ if client_redirect_url:
+ args["redirectUrl"] = client_redirect_url
+ if session:
+ args["session"] = session
+ username, user_display_name = await self._validate_ticket(ticket, args)
+
+ localpart = map_username_to_mxid_localpart(username)
+ user_id = UserID(localpart, self._hostname).to_string()
+ registered_user_id = await self._auth_handler.check_user_exists(user_id)
+
+ if session:
+ await self._auth_handler.complete_sso_ui_auth(
+ registered_user_id, session, request,
+ )
+
+ else:
+ if not registered_user_id:
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart, default_display_name=user_display_name
+ )
+
+ await self._auth_handler.complete_sso_login(
+ registered_user_id, request, client_redirect_url
+ )
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 993499f446..9bd941b5a0 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -338,8 +338,10 @@ class DeviceHandler(DeviceWorkerHandler):
else:
raise
- yield self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id
+ yield defer.ensureDeferred(
+ self._auth_handler.delete_access_tokens_for_user(
+ user_id, device_id=device_id
+ )
)
yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
@@ -391,8 +393,10 @@ class DeviceHandler(DeviceWorkerHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
- yield self._auth_handler.delete_access_tokens_for_user(
- user_id, device_id=device_id
+ yield defer.ensureDeferred(
+ self._auth_handler.delete_access_tokens_for_user(
+ user_id, device_id=device_id
+ )
)
yield self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 1d842c369b..f2f16b1e43 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -86,8 +86,7 @@ class DirectoryHandler(BaseHandler):
room_alias, room_id, servers, creator=creator
)
- @defer.inlineCallbacks
- def create_association(
+ async def create_association(
self,
requester: Requester,
room_alias: RoomAlias,
@@ -127,8 +126,12 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
else:
- if self.require_membership and check_membership:
- rooms_for_user = yield self.store.get_rooms_for_user(user_id)
+ # Server admins are not subject to the same constraints as normal
+ # users when creating an alias (e.g. being in the room).
+ is_admin = await self.auth.is_server_admin(requester.user)
+
+ if (self.require_membership and check_membership) and not is_admin:
+ rooms_for_user = await self.store.get_rooms_for_user(user_id)
if room_id not in rooms_for_user:
raise AuthError(
403, "You must be in the room to create an alias for it"
@@ -145,7 +148,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to create alias")
- can_create = yield self.can_modify_alias(room_alias, user_id=user_id)
+ can_create = await self.can_modify_alias(room_alias, user_id=user_id)
if not can_create:
raise AuthError(
400,
@@ -153,10 +156,9 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
- yield self._create_association(room_alias, room_id, servers, creator=user_id)
+ await self._create_association(room_alias, room_id, servers, creator=user_id)
- @defer.inlineCallbacks
- def delete_association(self, requester: Requester, room_alias: RoomAlias):
+ async def delete_association(self, requester: Requester, room_alias: RoomAlias):
"""Remove an alias from the directory
(this is only meant for human users; AS users should call
@@ -180,7 +182,7 @@ class DirectoryHandler(BaseHandler):
user_id = requester.user.to_string()
try:
- can_delete = yield self._user_can_delete_alias(room_alias, user_id)
+ can_delete = await self._user_can_delete_alias(room_alias, user_id)
except StoreError as e:
if e.code == 404:
raise NotFoundError("Unknown room alias")
@@ -189,7 +191,7 @@ class DirectoryHandler(BaseHandler):
if not can_delete:
raise AuthError(403, "You don't have permission to delete the alias.")
- can_delete = yield self.can_modify_alias(room_alias, user_id=user_id)
+ can_delete = await self.can_modify_alias(room_alias, user_id=user_id)
if not can_delete:
raise SynapseError(
400,
@@ -197,10 +199,10 @@ class DirectoryHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
- room_id = yield self._delete_association(room_alias)
+ room_id = await self._delete_association(room_alias)
try:
- yield self._update_canonical_alias(requester, user_id, room_id, room_alias)
+ await self._update_canonical_alias(requester, user_id, room_id, room_alias)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
@@ -292,15 +294,14 @@ class DirectoryHandler(BaseHandler):
Codes.NOT_FOUND,
)
- @defer.inlineCallbacks
- def _update_canonical_alias(
+ async def _update_canonical_alias(
self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
):
"""
Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field.
"""
- alias_event = yield self.state.get_current_state(
+ alias_event = await self.state.get_current_state(
room_id, EventTypes.CanonicalAlias, ""
)
@@ -331,7 +332,7 @@ class DirectoryHandler(BaseHandler):
del content["alt_aliases"]
if send_update:
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
@@ -372,8 +373,7 @@ class DirectoryHandler(BaseHandler):
# either no interested services, or no service with an exclusive lock
return defer.succeed(True)
- @defer.inlineCallbacks
- def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
+ async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
"""Determine whether a user can delete an alias.
One of the following must be true:
@@ -384,24 +384,23 @@ class DirectoryHandler(BaseHandler):
for the current room.
"""
- creator = yield self.store.get_room_alias_creator(alias.to_string())
+ creator = await self.store.get_room_alias_creator(alias.to_string())
if creator is not None and creator == user_id:
return True
# Resolve the alias to the corresponding room.
- room_mapping = yield self.get_association(alias)
+ room_mapping = await self.get_association(alias)
room_id = room_mapping["room_id"]
if not room_id:
return False
- res = yield self.auth.check_can_change_room_list(
+ res = await self.auth.check_can_change_room_list(
room_id, UserID.from_string(user_id)
)
return res
- @defer.inlineCallbacks
- def edit_published_room_list(
+ async def edit_published_room_list(
self, requester: Requester, room_id: str, visibility: str
):
"""Edit the entry of the room in the published room list.
@@ -429,11 +428,11 @@ class DirectoryHandler(BaseHandler):
403, "This user is not permitted to publish rooms to the room list"
)
- room = yield self.store.get_room(room_id)
+ room = await self.store.get_room(room_id)
if room is None:
raise SynapseError(400, "Unknown room")
- can_change_room_list = yield self.auth.check_can_change_room_list(
+ can_change_room_list = await self.auth.check_can_change_room_list(
room_id, requester.user
)
if not can_change_room_list:
@@ -445,8 +444,8 @@ class DirectoryHandler(BaseHandler):
making_public = visibility == "public"
if making_public:
- room_aliases = yield self.store.get_aliases_for_room(room_id)
- canonical_alias = yield self.store.get_canonical_alias_for_room(room_id)
+ room_aliases = await self.store.get_aliases_for_room(room_id)
+ canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
if canonical_alias:
room_aliases.append(canonical_alias)
@@ -458,7 +457,7 @@ class DirectoryHandler(BaseHandler):
# per alias creation rule?
raise SynapseError(403, "Not allowed to publish room")
- yield self.store.set_room_is_public(room_id, making_public)
+ await self.store.set_room_is_public(room_id, making_public)
@defer.inlineCallbacks
def edit_published_appservice_room_list(
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index ec18a42a68..71a89f09c7 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -19,6 +19,7 @@ import random
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import AuthError, SynapseError
from synapse.events import EventBase
+from synapse.handlers.presence import format_user_presence_state
from synapse.logging.utils import log_function
from synapse.types import UserID
from synapse.visibility import filter_events_for_client
@@ -97,6 +98,8 @@ class EventStreamHandler(BaseHandler):
explicit_room_id=room_id,
)
+ time_now = self.clock.time_msec()
+
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
to_add = []
@@ -112,19 +115,20 @@ class EventStreamHandler(BaseHandler):
users = await self.state.get_current_users_in_room(
event.room_id
)
- states = await presence_handler.get_states(users, as_event=True)
- to_add.extend(states)
else:
+ users = [event.state_key]
- ev = await presence_handler.get_state(
- UserID.from_string(event.state_key), as_event=True
- )
- to_add.append(ev)
+ states = await presence_handler.get_states(users)
+ to_add.extend(
+ {
+ "type": EventTypes.Presence,
+ "content": format_user_presence_state(state, time_now),
+ }
+ for state in states
+ )
events.extend(to_add)
- time_now = self.clock.time_msec()
-
chunks = await self._event_serializer.serialize_events(
events,
time_now,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 38ab6a8fc3..81d859f807 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -49,6 +49,7 @@ from synapse.event_auth import auth_types_for_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
+from synapse.handlers._base import BaseHandler
from synapse.logging.context import (
make_deferred_yieldable,
nested_logging_context,
@@ -69,10 +70,9 @@ from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
+from synapse.util.stringutils import shortstr
from synapse.visibility import filter_events_for_server
-from ._base import BaseHandler
-
logger = logging.getLogger(__name__)
@@ -93,27 +93,6 @@ class _NewEventInfo:
auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
-def shortstr(iterable, maxitems=5):
- """If iterable has maxitems or fewer, return the stringification of a list
- containing those items.
-
- Otherwise, return the stringification of a a list with the first maxitems items,
- followed by "...".
-
- Args:
- iterable (Iterable): iterable to truncate
- maxitems (int): number of items to return before truncating
-
- Returns:
- unicode
- """
-
- items = list(itertools.islice(iterable, maxitems + 1))
- if len(items) <= maxitems:
- return str(items)
- return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
-
-
class FederationHandler(BaseHandler):
"""Handles events that originated from federation.
Responsible for:
@@ -364,7 +343,7 @@ class FederationHandler(BaseHandler):
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
- state_maps = list(ours.values()) # type: list[StateMap[str]]
+ state_maps = list(ours.values()) # type: List[StateMap[str]]
# we don't need this any more, let's delete it.
del ours
@@ -1715,16 +1694,15 @@ class FederationHandler(BaseHandler):
return None
- @defer.inlineCallbacks
- def get_state_for_pdu(self, room_id, event_id):
+ async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
"""Returns the state at the event. i.e. not including said event.
"""
- event = yield self.store.get_event(
+ event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
+ state_groups = await self.state_store.get_state_groups(room_id, [event_id])
if state_groups:
_, state = list(iteritems(state_groups)).pop()
@@ -1735,7 +1713,7 @@ class FederationHandler(BaseHandler):
if "replaces_state" in event.unsigned:
prev_id = event.unsigned["replaces_state"]
if prev_id != event.event_id:
- prev_event = yield self.store.get_event(prev_id)
+ prev_event = await self.store.get_event(prev_id)
results[(event.type, event.state_key)] = prev_event
else:
del results[(event.type, event.state_key)]
@@ -1745,15 +1723,14 @@ class FederationHandler(BaseHandler):
else:
return []
- @defer.inlineCallbacks
- def get_state_ids_for_pdu(self, room_id, event_id):
+ async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
- event = yield self.store.get_event(
+ event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
- state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
+ state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
if state_groups:
_, state = list(state_groups.items()).pop()
@@ -1772,49 +1749,50 @@ class FederationHandler(BaseHandler):
else:
return []
- @defer.inlineCallbacks
@log_function
- def on_backfill_request(self, origin, room_id, pdu_list, limit):
- in_room = yield self.auth.check_host_in_room(room_id, origin)
+ async def on_backfill_request(
+ self, origin: str, room_id: str, pdu_list: List[str], limit: int
+ ) -> List[EventBase]:
+ in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
# Synapse asks for 100 events per backfill request. Do not allow more.
limit = min(limit, 100)
- events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
+ events = await self.store.get_backfill_events(room_id, pdu_list, limit)
- events = yield filter_events_for_server(self.storage, origin, events)
+ events = await filter_events_for_server(self.storage, origin, events)
return events
- @defer.inlineCallbacks
@log_function
- def get_persisted_pdu(self, origin, event_id):
+ async def get_persisted_pdu(
+ self, origin: str, event_id: str
+ ) -> Optional[EventBase]:
"""Get an event from the database for the given server.
Args:
- origin [str]: hostname of server which is requesting the event; we
+ origin: hostname of server which is requesting the event; we
will check that the server is allowed to see it.
- event_id [str]: id of the event being requested
+ event_id: id of the event being requested
Returns:
- Deferred[EventBase|None]: None if we know nothing about the event;
- otherwise the (possibly-redacted) event.
+ None if we know nothing about the event; otherwise the (possibly-redacted) event.
Raises:
AuthError if the server is not currently in the room
"""
- event = yield self.store.get_event(
+ event = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
if event:
- in_room = yield self.auth.check_host_in_room(event.room_id, origin)
+ in_room = await self.auth.check_host_in_room(event.room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
- events = yield filter_events_for_server(self.storage, origin, [event])
+ events = await filter_events_for_server(self.storage, origin, [event])
event = events[0]
return event
else:
@@ -2418,7 +2396,7 @@ class FederationHandler(BaseHandler):
"""
# exclude the state key of the new event from the current_state in the context.
if event.is_state():
- event_key = (event.type, event.state_key)
+ event_key = (event.type, event.state_key) # type: Optional[Tuple[str, str]]
else:
event_key = None
state_updates = {
@@ -2584,9 +2562,8 @@ class FederationHandler(BaseHandler):
"missing": [e.event_id for e in missing_locals],
}
- @defer.inlineCallbacks
@log_function
- def exchange_third_party_invite(
+ async def exchange_third_party_invite(
self, sender_user_id, target_user_id, room_id, signed
):
third_party_invite = {"signed": signed}
@@ -2602,16 +2579,16 @@ class FederationHandler(BaseHandler):
"state_key": target_user_id,
}
- if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
- room_version = yield self.store.get_room_version_id(room_id)
+ if await self.auth.check_host_in_room(room_id, self.hs.hostname):
+ room_version = await self.store.get_room_version_id(room_id)
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
- event, context = yield self.event_creation_handler.create_new_client_event(
+ event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
- event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@@ -2623,7 +2600,7 @@ class FederationHandler(BaseHandler):
403, "This event is not allowed in this context", Codes.FORBIDDEN
)
- event, context = yield self.add_display_name_to_third_party_invite(
+ event, context = await self.add_display_name_to_third_party_invite(
room_version, event_dict, event, context
)
@@ -2634,19 +2611,19 @@ class FederationHandler(BaseHandler):
event.internal_metadata.send_on_behalf_of = self.hs.hostname
try:
- yield self.auth.check_from_context(room_version, event, context)
+ await self.auth.check_from_context(room_version, event, context)
except AuthError as e:
logger.warning("Denying new third party invite %r because %s", event, e)
raise e
- yield self._check_signature(event, context)
+ await self._check_signature(event, context)
# We retrieve the room member handler here as to not cause a cyclic dependency
member_handler = self.hs.get_room_member_handler()
- yield member_handler.send_membership_event(None, event, context)
+ await member_handler.send_membership_event(None, event, context)
else:
destinations = {x.split(":", 1)[-1] for x in (sender_user_id, room_id)}
- yield self.federation_client.forward_third_party_invite(
+ await self.federation_client.forward_third_party_invite(
destinations, room_id, event_dict
)
@@ -2704,8 +2681,7 @@ class FederationHandler(BaseHandler):
member_handler = self.hs.get_room_member_handler()
await member_handler.send_membership_event(None, event, context)
- @defer.inlineCallbacks
- def add_display_name_to_third_party_invite(
+ async def add_display_name_to_third_party_invite(
self, room_version, event_dict, event, context
):
key = (
@@ -2713,10 +2689,10 @@ class FederationHandler(BaseHandler):
event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids()
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
- original_invite = yield self.store.get_event(
+ original_invite = await self.store.get_event(
original_invite_id, allow_none=True
)
if original_invite:
@@ -2737,14 +2713,13 @@ class FederationHandler(BaseHandler):
builder = self.event_builder_factory.new(room_version, event_dict)
EventValidator().validate_builder(builder)
- event, context = yield self.event_creation_handler.create_new_client_event(
+ event, context = await self.event_creation_handler.create_new_client_event(
builder=builder
)
EventValidator().validate_new(event, self.config)
return (event, context)
- @defer.inlineCallbacks
- def _check_signature(self, event, context):
+ async def _check_signature(self, event, context):
"""
Checks that the signature in the event is consistent with its invite.
@@ -2761,12 +2736,12 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids()
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
invite_event = None
if invite_event_id:
- invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
+ invite_event = await self.store.get_event(invite_event_id, allow_none=True)
if not invite_event:
raise AuthError(403, "Could not find invite")
@@ -2815,7 +2790,7 @@ class FederationHandler(BaseHandler):
raise
try:
if "key_validity_url" in public_key_object:
- yield self._check_key_revocation(
+ await self._check_key_revocation(
public_key, public_key_object["key_validity_url"]
)
except Exception:
@@ -2829,8 +2804,7 @@ class FederationHandler(BaseHandler):
last_exception = e
raise last_exception
- @defer.inlineCallbacks
- def _check_key_revocation(self, public_key, url):
+ async def _check_key_revocation(self, public_key, url):
"""
Checks whether public_key has been revoked.
@@ -2844,7 +2818,7 @@ class FederationHandler(BaseHandler):
for revocation.
"""
try:
- response = yield self.http_client.get_json(url, {"public_key": public_key})
+ response = await self.http_client.get_json(url, {"public_key": public_key})
except Exception:
raise SynapseError(502, "Third party certificate could not be checked")
if "valid" not in response or not response["valid"]:
@@ -2939,8 +2913,7 @@ class FederationHandler(BaseHandler):
else:
user_joined_room(self.distributor, user, room_id)
- @defer.inlineCallbacks
- def get_room_complexity(self, remote_room_hosts, room_id):
+ async def get_room_complexity(self, remote_room_hosts, room_id):
"""
Fetch the complexity of a remote room over federation.
@@ -2954,12 +2927,12 @@ class FederationHandler(BaseHandler):
"""
for host in remote_room_hosts:
- res = yield self.federation_client.get_room_complexity(host, room_id)
+ res = await self.federation_client.get_room_complexity(host, room_id)
# We got a result, return it.
if res:
- defer.returnValue(res)
+ return res
# We fell off the bottom, couldn't get the complexity from anyone. Oh
# well.
- defer.returnValue(None)
+ return None
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index ad22415782..ca5c83811a 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -284,15 +284,14 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
set_group_join_policy = _create_rerouter("set_group_join_policy")
- @defer.inlineCallbacks
- def create_group(self, group_id, user_id, content):
+ async def create_group(self, group_id, user_id, content):
"""Create a group
"""
logger.info("Asking to create group with ID: %r", group_id)
if self.is_mine_id(group_id):
- res = yield self.groups_server_handler.create_group(
+ res = await self.groups_server_handler.create_group(
group_id, user_id, content
)
local_attestation = None
@@ -301,10 +300,10 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation
- content["user_profile"] = yield self.profile_handler.get_profile(user_id)
+ content["user_profile"] = await self.profile_handler.get_profile(user_id)
try:
- res = yield self.transport_client.create_group(
+ res = await self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content
)
except HttpResponseException as e:
@@ -313,7 +312,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"]
- yield self.attestations.verify_attestation(
+ await self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
@@ -321,7 +320,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
)
is_publicised = content.get("publicise", False)
- token = yield self.store.register_user_group_membership(
+ token = await self.store.register_user_group_membership(
group_id,
user_id,
membership="join",
@@ -482,12 +481,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile}
- @defer.inlineCallbacks
- def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
+ async def remove_user_from_group(
+ self, group_id, user_id, requester_user_id, content
+ ):
"""Remove a user from a group
"""
if user_id == requester_user_id:
- token = yield self.store.register_user_group_membership(
+ token = await self.store.register_user_group_membership(
group_id, user_id, membership="leave"
)
self.notifier.on_new_event("groups_key", token, users=[user_id])
@@ -496,13 +496,13 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
# retry if the group server is currently down.
if self.is_mine_id(group_id):
- res = yield self.groups_server_handler.remove_user_from_group(
+ res = await self.groups_server_handler.remove_user_from_group(
group_id, user_id, requester_user_id, content
)
else:
content["requester_user_id"] = requester_user_id
try:
- res = yield self.transport_client.remove_user_from_group(
+ res = await self.transport_client.remove_user_from_group(
get_domain_from_id(group_id),
group_id,
requester_user_id,
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 23f07832e7..0f0e632b62 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -18,7 +18,7 @@
"""Utilities for interacting with Identity Servers"""
import logging
-import urllib
+import urllib.parse
from canonicaljson import json
from signedjson.key import decode_verify_key_bytes
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index b116500c7d..f88bad5f25 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -381,10 +381,16 @@ class InitialSyncHandler(BaseHandler):
return []
states = await presence_handler.get_states(
- [m.user_id for m in room_members], as_event=True
+ [m.user_id for m in room_members]
)
- return states
+ return [
+ {
+ "type": EventTypes.Presence,
+ "content": format_user_presence_state(s, time_now),
+ }
+ for s in states
+ ]
async def get_receipts():
receipts = await self.store.get_linearized_receipts_for_room(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index b743fc2dcc..0242521cc6 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -72,7 +72,6 @@ class MessageHandler(object):
self.state_store = self.storage.state
self._event_serializer = hs.get_event_client_serializer()
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
- self._is_worker_app = bool(hs.config.worker_app)
# The scheduled call to self._expire_event. None if no call is currently
# scheduled.
@@ -260,7 +259,6 @@ class MessageHandler(object):
Args:
event (EventBase): The event to schedule the expiry of.
"""
- assert not self._is_worker_app
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
if not isinstance(expiry_ts, int) or event.is_state():
@@ -419,6 +417,8 @@ class EventCreationHandler(object):
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
+ self._dummy_events_threshold = hs.config.dummy_events_threshold
+
@defer.inlineCallbacks
def create_event(
self,
@@ -626,8 +626,7 @@ class EventCreationHandler(object):
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
- @defer.inlineCallbacks
- def send_nonmember_event(self, requester, event, context, ratelimit=True):
+ async def send_nonmember_event(self, requester, event, context, ratelimit=True):
"""
Persists and notifies local clients and federation of an event.
@@ -647,7 +646,7 @@ class EventCreationHandler(object):
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
if event.is_state():
- prev_state = yield self.deduplicate_state_event(event, context)
+ prev_state = await self.deduplicate_state_event(event, context)
if prev_state is not None:
logger.info(
"Not bothering to persist state event %s duplicated by %s",
@@ -656,7 +655,7 @@ class EventCreationHandler(object):
)
return prev_state
- yield self.handle_new_client_event(
+ await self.handle_new_client_event(
requester=requester, event=event, context=context, ratelimit=ratelimit
)
@@ -683,8 +682,7 @@ class EventCreationHandler(object):
return prev_event
return
- @defer.inlineCallbacks
- def create_and_send_nonmember_event(
+ async def create_and_send_nonmember_event(
self, requester, event_dict, ratelimit=True, txn_id=None
):
"""
@@ -698,8 +696,8 @@ class EventCreationHandler(object):
# a situation where event persistence can't keep up, causing
# extremities to pile up, which in turn leads to state resolution
# taking longer.
- with (yield self.limiter.queue(event_dict["room_id"])):
- event, context = yield self.create_event(
+ with (await self.limiter.queue(event_dict["room_id"])):
+ event, context = await self.create_event(
requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
)
@@ -709,7 +707,7 @@ class EventCreationHandler(object):
spam_error = "Spam is not permitted here"
raise SynapseError(403, spam_error, Codes.FORBIDDEN)
- yield self.send_nonmember_event(
+ await self.send_nonmember_event(
requester, event, context, ratelimit=ratelimit
)
return event
@@ -770,8 +768,7 @@ class EventCreationHandler(object):
return (event, context)
@measure_func("handle_new_client_event")
- @defer.inlineCallbacks
- def handle_new_client_event(
+ async def handle_new_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Processes a new event. This includes checking auth, persisting it,
@@ -794,9 +791,9 @@ class EventCreationHandler(object):
):
room_version = event.content.get("room_version", RoomVersions.V1.identifier)
else:
- room_version = yield self.store.get_room_version_id(event.room_id)
+ room_version = await self.store.get_room_version_id(event.room_id)
- event_allowed = yield self.third_party_event_rules.check_event_allowed(
+ event_allowed = await self.third_party_event_rules.check_event_allowed(
event, context
)
if not event_allowed:
@@ -805,7 +802,7 @@ class EventCreationHandler(object):
)
try:
- yield self.auth.check_from_context(room_version, event, context)
+ await self.auth.check_from_context(room_version, event, context)
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
raise err
@@ -818,7 +815,7 @@ class EventCreationHandler(object):
logger.exception("Failed to encode content: %r", event.content)
raise
- yield self.action_generator.handle_push_actions_for_event(event, context)
+ await self.action_generator.handle_push_actions_for_event(event, context)
# reraise does not allow inlineCallbacks to preserve the stacktrace, so we
# hack around with a try/finally instead.
@@ -826,7 +823,7 @@ class EventCreationHandler(object):
try:
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
- yield self.send_event_to_master(
+ await self.send_event_to_master(
event_id=event.event_id,
store=self.store,
requester=requester,
@@ -838,7 +835,7 @@ class EventCreationHandler(object):
success = True
return
- yield self.persist_and_notify_client_event(
+ await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
)
@@ -852,7 +849,38 @@ class EventCreationHandler(object):
)
@defer.inlineCallbacks
- def persist_and_notify_client_event(
+ def _validate_canonical_alias(
+ self, directory_handler, room_alias_str, expected_room_id
+ ):
+ """
+ Ensure that the given room alias points to the expected room ID.
+
+ Args:
+ directory_handler: The directory handler object.
+ room_alias_str: The room alias to check.
+ expected_room_id: The room ID that the alias should point to.
+ """
+ room_alias = RoomAlias.from_string(room_alias_str)
+ try:
+ mapping = yield directory_handler.get_association(room_alias)
+ except SynapseError as e:
+ # Turn M_NOT_FOUND errors into M_BAD_ALIAS errors.
+ if e.errcode == Codes.NOT_FOUND:
+ raise SynapseError(
+ 400,
+ "Room alias %s does not point to the room" % (room_alias_str,),
+ Codes.BAD_ALIAS,
+ )
+ raise
+
+ if mapping["room_id"] != expected_room_id:
+ raise SynapseError(
+ 400,
+ "Room alias %s does not point to the room" % (room_alias_str,),
+ Codes.BAD_ALIAS,
+ )
+
+ async def persist_and_notify_client_event(
self, requester, event, context, ratelimit=True, extra_users=[]
):
"""Called when we have fully built the event, have already
@@ -869,7 +897,7 @@ class EventCreationHandler(object):
# user is actually admin or not).
is_admin_redaction = False
if event.type == EventTypes.Redaction:
- original_event = yield self.store.get_event(
+ original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
@@ -881,11 +909,11 @@ class EventCreationHandler(object):
original_event and event.sender != original_event.sender
)
- yield self.base_handler.ratelimit(
+ await self.base_handler.ratelimit(
requester, is_admin_redaction=is_admin_redaction
)
- yield self.base_handler.maybe_kick_guest_users(event, context)
+ await self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
# Validate a newly added alias or newly added alt_aliases.
@@ -895,7 +923,7 @@ class EventCreationHandler(object):
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
- original_event = yield self.store.get_event(original_event_id)
+ original_event = await self.store.get_event(original_event_id)
if original_event:
original_alias = original_event.content.get("alias", None)
@@ -905,15 +933,9 @@ class EventCreationHandler(object):
room_alias_str = event.content.get("alias", None)
directory_handler = self.hs.get_handlers().directory_handler
if room_alias_str and room_alias_str != original_alias:
- room_alias = RoomAlias.from_string(room_alias_str)
- mapping = yield directory_handler.get_association(room_alias)
-
- if mapping["room_id"] != event.room_id:
- raise SynapseError(
- 400,
- "Room alias %s does not point to the room" % (room_alias_str,),
- Codes.BAD_ALIAS,
- )
+ await self._validate_canonical_alias(
+ directory_handler, room_alias_str, event.room_id
+ )
# Check that alt_aliases is the proper form.
alt_aliases = event.content.get("alt_aliases", [])
@@ -931,16 +953,9 @@ class EventCreationHandler(object):
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
if new_alt_aliases:
for alias_str in new_alt_aliases:
- room_alias = RoomAlias.from_string(alias_str)
- mapping = yield directory_handler.get_association(room_alias)
-
- if mapping["room_id"] != event.room_id:
- raise SynapseError(
- 400,
- "Room alias %s does not point to the room"
- % (room_alias_str,),
- Codes.BAD_ALIAS,
- )
+ await self._validate_canonical_alias(
+ directory_handler, alias_str, event.room_id
+ )
federation_handler = self.hs.get_handlers().federation_handler
@@ -950,7 +965,7 @@ class EventCreationHandler(object):
def is_inviter_member_event(e):
return e.type == EventTypes.Member and e.sender == event.sender
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = await context.get_current_state_ids()
state_to_include_ids = [
e_id
@@ -959,7 +974,7 @@ class EventCreationHandler(object):
or k == (EventTypes.Member, event.sender)
]
- state_to_include = yield self.store.get_events(state_to_include_ids)
+ state_to_include = await self.store.get_events(state_to_include_ids)
event.unsigned["invite_room_state"] = [
{
@@ -977,8 +992,8 @@ class EventCreationHandler(object):
# way? If we have been invited by a remote server, we need
# to get them to sign the event.
- returned_invite = yield defer.ensureDeferred(
- federation_handler.send_invite(invitee.domain, event)
+ returned_invite = await federation_handler.send_invite(
+ invitee.domain, event
)
event.unsigned.pop("room_state", None)
@@ -986,7 +1001,7 @@ class EventCreationHandler(object):
event.signatures.update(returned_invite.signatures)
if event.type == EventTypes.Redaction:
- original_event = yield self.store.get_event(
+ original_event = await self.store.get_event(
event.redacts,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
@@ -1002,14 +1017,14 @@ class EventCreationHandler(object):
if original_event.room_id != event.room_id:
raise SynapseError(400, "Cannot redact event from a different room")
- prev_state_ids = yield context.get_prev_state_ids()
- auth_events_ids = yield self.auth.compute_auth_events(
+ prev_state_ids = await context.get_prev_state_ids()
+ auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
- auth_events = yield self.store.get_events(auth_events_ids)
+ auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
- room_version = yield self.store.get_room_version_id(event.room_id)
+ room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
if event_auth.check_redaction(
@@ -1028,11 +1043,11 @@ class EventCreationHandler(object):
event.internal_metadata.recheck_redaction = False
if event.type == EventTypes.Create:
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids()
if prev_state_ids:
raise AuthError(403, "Changing the room create event is forbidden")
- event_stream_id, max_stream_id = yield self.storage.persistence.persist_event(
+ event_stream_id, max_stream_id = await self.storage.persistence.persist_event(
event, context=context
)
@@ -1040,7 +1055,7 @@ class EventCreationHandler(object):
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
- yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
+ await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
def _notify():
try:
@@ -1064,14 +1079,13 @@ class EventCreationHandler(object):
except Exception:
logger.exception("Error bumping presence active time")
- @defer.inlineCallbacks
- def _send_dummy_events_to_fill_extremities(self):
+ async def _send_dummy_events_to_fill_extremities(self):
"""Background task to send dummy events into rooms that have a large
number of extremities
"""
self._expire_rooms_to_exclude_from_dummy_event_insertion()
- room_ids = yield self.store.get_rooms_with_many_extremities(
- min_count=10,
+ room_ids = await self.store.get_rooms_with_many_extremities(
+ min_count=self._dummy_events_threshold,
limit=5,
room_id_filter=self._rooms_to_exclude_from_dummy_event_insertion.keys(),
)
@@ -1080,9 +1094,9 @@ class EventCreationHandler(object):
# For each room we need to find a joined member we can use to send
# the dummy event with.
- latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
+ latest_event_ids = await self.store.get_prev_events_for_room(room_id)
- members = yield self.state.get_current_users_in_room(
+ members = await self.state.get_current_users_in_room(
room_id, latest_event_ids=latest_event_ids
)
dummy_event_sent = False
@@ -1091,7 +1105,7 @@ class EventCreationHandler(object):
continue
requester = create_requester(user_id)
try:
- event, context = yield self.create_event(
+ event, context = await self.create_event(
requester,
{
"type": "org.matrix.dummy_event",
@@ -1104,7 +1118,7 @@ class EventCreationHandler(object):
event.internal_metadata.proactively_send = False
- yield self.send_nonmember_event(
+ await self.send_nonmember_event(
requester, event, context, ratelimit=False
)
dummy_event_sent = True
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
new file mode 100644
index 0000000000..178f263439
--- /dev/null
+++ b/synapse/handlers/oidc_handler.py
@@ -0,0 +1,998 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import json
+import logging
+from typing import Dict, Generic, List, Optional, Tuple, TypeVar
+from urllib.parse import urlencode
+
+import attr
+import pymacaroons
+from authlib.common.security import generate_token
+from authlib.jose import JsonWebToken
+from authlib.oauth2.auth import ClientAuth
+from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
+from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
+from authlib.oidc.discovery import OpenIDProviderMetadata, get_well_known_url
+from jinja2 import Environment, Template
+from pymacaroons.exceptions import (
+ MacaroonDeserializationException,
+ MacaroonInvalidSignatureException,
+)
+from typing_extensions import TypedDict
+
+from twisted.web.client import readBody
+
+from synapse.config import ConfigError
+from synapse.http.server import finish_request
+from synapse.http.site import SynapseRequest
+from synapse.push.mailer import load_jinja2_templates
+from synapse.server import HomeServer
+from synapse.types import UserID, map_username_to_mxid_localpart
+
+logger = logging.getLogger(__name__)
+
+SESSION_COOKIE_NAME = b"oidc_session"
+
+#: A token exchanged from the token endpoint, as per RFC6749 sec 5.1. and
+#: OpenID.Core sec 3.1.3.3.
+Token = TypedDict(
+ "Token",
+ {
+ "access_token": str,
+ "token_type": str,
+ "id_token": Optional[str],
+ "refresh_token": Optional[str],
+ "expires_in": int,
+ "scope": Optional[str],
+ },
+)
+
+#: A JWK, as per RFC7517 sec 4. The type could be more precise than that, but
+#: there is no real point of doing this in our case.
+JWK = Dict[str, str]
+
+#: A JWK Set, as per RFC7517 sec 5.
+JWKS = TypedDict("JWKS", {"keys": List[JWK]})
+
+
+class OidcError(Exception):
+ """Used to catch errors when calling the token_endpoint
+ """
+
+ def __init__(self, error, error_description=None):
+ self.error = error
+ self.error_description = error_description
+
+ def __str__(self):
+ if self.error_description:
+ return "{}: {}".format(self.error, self.error_description)
+ return self.error
+
+
+class MappingException(Exception):
+ """Used to catch errors when mapping the UserInfo object
+ """
+
+
+class OidcHandler:
+ """Handles requests related to the OpenID Connect login flow.
+ """
+
+ def __init__(self, hs: HomeServer):
+ self._callback_url = hs.config.oidc_callback_url # type: str
+ self._scopes = hs.config.oidc_scopes # type: List[str]
+ self._client_auth = ClientAuth(
+ hs.config.oidc_client_id,
+ hs.config.oidc_client_secret,
+ hs.config.oidc_client_auth_method,
+ ) # type: ClientAuth
+ self._client_auth_method = hs.config.oidc_client_auth_method # type: str
+ self._subject_claim = hs.config.oidc_subject_claim
+ self._provider_metadata = OpenIDProviderMetadata(
+ issuer=hs.config.oidc_issuer,
+ authorization_endpoint=hs.config.oidc_authorization_endpoint,
+ token_endpoint=hs.config.oidc_token_endpoint,
+ userinfo_endpoint=hs.config.oidc_userinfo_endpoint,
+ jwks_uri=hs.config.oidc_jwks_uri,
+ ) # type: OpenIDProviderMetadata
+ self._provider_needs_discovery = hs.config.oidc_discover # type: bool
+ self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class(
+ hs.config.oidc_user_mapping_provider_config
+ ) # type: OidcMappingProvider
+ self._skip_verification = hs.config.oidc_skip_verification # type: bool
+
+ self._http_client = hs.get_proxied_http_client()
+ self._auth_handler = hs.get_auth_handler()
+ self._registration_handler = hs.get_registration_handler()
+ self._datastore = hs.get_datastore()
+ self._clock = hs.get_clock()
+ self._hostname = hs.hostname # type: str
+ self._server_name = hs.config.server_name # type: str
+ self._macaroon_secret_key = hs.config.macaroon_secret_key
+ self._error_template = load_jinja2_templates(
+ hs.config.sso_template_dir, ["sso_error.html"]
+ )[0]
+
+ # identifier for the external_ids table
+ self._auth_provider_id = "oidc"
+
+ def _render_error(
+ self, request, error: str, error_description: Optional[str] = None
+ ) -> None:
+ """Renders the error template and respond with it.
+
+ This is used to show errors to the user. The template of this page can
+ be found under ``synapse/res/templates/sso_error.html``.
+
+ Args:
+ request: The incoming request from the browser.
+ We'll respond with an HTML page describing the error.
+ error: A technical identifier for this error. Those include
+ well-known OAuth2/OIDC error types like invalid_request or
+ access_denied.
+ error_description: A human-readable description of the error.
+ """
+ html_bytes = self._error_template.render(
+ error=error, error_description=error_description
+ ).encode("utf-8")
+
+ request.setResponseCode(400)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%i" % len(html_bytes))
+ request.write(html_bytes)
+ finish_request(request)
+
+ def _validate_metadata(self):
+ """Verifies the provider metadata.
+
+ This checks the validity of the currently loaded provider. Not
+ everything is checked, only:
+
+ - ``issuer``
+ - ``authorization_endpoint``
+ - ``token_endpoint``
+ - ``response_types_supported`` (checks if "code" is in it)
+ - ``jwks_uri``
+
+ Raises:
+ ValueError: if something in the provider is not valid
+ """
+ # Skip verification to allow non-compliant providers (e.g. issuers not running on a secure origin)
+ if self._skip_verification is True:
+ return
+
+ m = self._provider_metadata
+ m.validate_issuer()
+ m.validate_authorization_endpoint()
+ m.validate_token_endpoint()
+
+ if m.get("token_endpoint_auth_methods_supported") is not None:
+ m.validate_token_endpoint_auth_methods_supported()
+ if (
+ self._client_auth_method
+ not in m["token_endpoint_auth_methods_supported"]
+ ):
+ raise ValueError(
+ '"{auth_method}" not in "token_endpoint_auth_methods_supported" ({supported!r})'.format(
+ auth_method=self._client_auth_method,
+ supported=m["token_endpoint_auth_methods_supported"],
+ )
+ )
+
+ if m.get("response_types_supported") is not None:
+ m.validate_response_types_supported()
+
+ if "code" not in m["response_types_supported"]:
+ raise ValueError(
+ '"code" not in "response_types_supported" (%r)'
+ % (m["response_types_supported"],)
+ )
+
+ # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos
+ if self._uses_userinfo:
+ if m.get("userinfo_endpoint") is None:
+ raise ValueError(
+ 'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
+ )
+ else:
+ # If we're not using userinfo, we need a valid jwks to validate the ID token
+ if m.get("jwks") is None:
+ if m.get("jwks_uri") is not None:
+ m.validate_jwks_uri()
+ else:
+ raise ValueError('"jwks_uri" must be set')
+
+ @property
+ def _uses_userinfo(self) -> bool:
+ """Returns True if the ``userinfo_endpoint`` should be used.
+
+ This is based on the requested scopes: if the scopes include
+ ``openid``, the provider should give use an ID token containing the
+ user informations. If not, we should fetch them using the
+ ``access_token`` with the ``userinfo_endpoint``.
+ """
+
+ # Maybe that should be user-configurable and not inferred?
+ return "openid" not in self._scopes
+
+ async def load_metadata(self) -> OpenIDProviderMetadata:
+ """Load and validate the provider metadata.
+
+ The values metadatas are discovered if ``oidc_config.discovery`` is
+ ``True`` and then cached.
+
+ Raises:
+ ValueError: if something in the provider is not valid
+
+ Returns:
+ The provider's metadata.
+ """
+ # If we are using the OpenID Discovery documents, it needs to be loaded once
+ # FIXME: should there be a lock here?
+ if self._provider_needs_discovery:
+ url = get_well_known_url(self._provider_metadata["issuer"], external=True)
+ metadata_response = await self._http_client.get_json(url)
+ # TODO: maybe update the other way around to let user override some values?
+ self._provider_metadata.update(metadata_response)
+ self._provider_needs_discovery = False
+
+ self._validate_metadata()
+
+ return self._provider_metadata
+
+ async def load_jwks(self, force: bool = False) -> JWKS:
+ """Load the JSON Web Key Set used to sign ID tokens.
+
+ If we're not using the ``userinfo_endpoint``, user infos are extracted
+ from the ID token, which is a JWT signed by keys given by the provider.
+ The keys are then cached.
+
+ Args:
+ force: Force reloading the keys.
+
+ Returns:
+ The key set
+
+ Looks like this::
+
+ {
+ 'keys': [
+ {
+ 'kid': 'abcdef',
+ 'kty': 'RSA',
+ 'alg': 'RS256',
+ 'use': 'sig',
+ 'e': 'XXXX',
+ 'n': 'XXXX',
+ }
+ ]
+ }
+ """
+ if self._uses_userinfo:
+ # We're not using jwt signing, return an empty jwk set
+ return {"keys": []}
+
+ # First check if the JWKS are loaded in the provider metadata.
+ # It can happen either if the provider gives its JWKS in the discovery
+ # document directly or if it was already loaded once.
+ metadata = await self.load_metadata()
+ jwk_set = metadata.get("jwks")
+ if jwk_set is not None and not force:
+ return jwk_set
+
+ # Loading the JWKS using the `jwks_uri` metadata
+ uri = metadata.get("jwks_uri")
+ if not uri:
+ raise RuntimeError('Missing "jwks_uri" in metadata')
+
+ jwk_set = await self._http_client.get_json(uri)
+
+ # Caching the JWKS in the provider's metadata
+ self._provider_metadata["jwks"] = jwk_set
+ return jwk_set
+
+ async def _exchange_code(self, code: str) -> Token:
+ """Exchange an authorization code for a token.
+
+ This calls the ``token_endpoint`` with the authorization code we
+ received in the callback to exchange it for a token. The call uses the
+ ``ClientAuth`` to authenticate with the client with its ID and secret.
+
+ Args:
+ code: The autorization code we got from the callback.
+
+ Returns:
+ A dict containing various tokens.
+
+ May look like this::
+
+ {
+ 'token_type': 'bearer',
+ 'access_token': 'abcdef',
+ 'expires_in': 3599,
+ 'id_token': 'ghijkl',
+ 'refresh_token': 'mnopqr',
+ }
+
+ Raises:
+ OidcError: when the ``token_endpoint`` returned an error.
+ """
+ metadata = await self.load_metadata()
+ token_endpoint = metadata.get("token_endpoint")
+ headers = {
+ "Content-Type": "application/x-www-form-urlencoded",
+ "User-Agent": self._http_client.user_agent,
+ "Accept": "application/json",
+ }
+
+ args = {
+ "grant_type": "authorization_code",
+ "code": code,
+ "redirect_uri": self._callback_url,
+ }
+ body = urlencode(args, True)
+
+ # Fill the body/headers with credentials
+ uri, headers, body = self._client_auth.prepare(
+ method="POST", uri=token_endpoint, headers=headers, body=body
+ )
+ headers = {k: [v] for (k, v) in headers.items()}
+
+ # Do the actual request
+ # We're not using the SimpleHttpClient util methods as we don't want to
+ # check the HTTP status code and we do the body encoding ourself.
+ response = await self._http_client.request(
+ method="POST", uri=uri, data=body.encode("utf-8"), headers=headers,
+ )
+
+ # This is used in multiple error messages below
+ status = "{code} {phrase}".format(
+ code=response.code, phrase=response.phrase.decode("utf-8")
+ )
+
+ resp_body = await readBody(response)
+
+ if response.code >= 500:
+ # In case of a server error, we should first try to decode the body
+ # and check for an error field. If not, we respond with a generic
+ # error message.
+ try:
+ resp = json.loads(resp_body.decode("utf-8"))
+ error = resp["error"]
+ description = resp.get("error_description", error)
+ except (ValueError, KeyError):
+ # Catch ValueError for the JSON decoding and KeyError for the "error" field
+ error = "server_error"
+ description = (
+ (
+ 'Authorization server responded with a "{status}" error '
+ "while exchanging the authorization code."
+ ).format(status=status),
+ )
+
+ raise OidcError(error, description)
+
+ # Since it is a not a 5xx code, body should be a valid JSON. It will
+ # raise if not.
+ resp = json.loads(resp_body.decode("utf-8"))
+
+ if "error" in resp:
+ error = resp["error"]
+ # In case the authorization server responded with an error field,
+ # it should be a 4xx code. If not, warn about it but don't do
+ # anything special and report the original error message.
+ if response.code < 400:
+ logger.debug(
+ "Invalid response from the authorization server: "
+ 'responded with a "{status}" '
+ "but body has an error field: {error!r}".format(
+ status=status, error=resp["error"]
+ )
+ )
+
+ description = resp.get("error_description", error)
+ raise OidcError(error, description)
+
+ # Now, this should not be an error. According to RFC6749 sec 5.1, it
+ # should be a 200 code. We're a bit more flexible than that, and will
+ # only throw on a 4xx code.
+ if response.code >= 400:
+ description = (
+ 'Authorization server responded with a "{status}" error '
+ 'but did not include an "error" field in its response.'.format(
+ status=status
+ )
+ )
+ logger.warning(description)
+ # Body was still valid JSON. Might be useful to log it for debugging.
+ logger.warning("Code exchange response: {resp!r}".format(resp=resp))
+ raise OidcError("server_error", description)
+
+ return resp
+
+ async def _fetch_userinfo(self, token: Token) -> UserInfo:
+ """Fetch user informations from the ``userinfo_endpoint``.
+
+ Args:
+ token: the token given by the ``token_endpoint``.
+ Must include an ``access_token`` field.
+
+ Returns:
+ UserInfo: an object representing the user.
+ """
+ metadata = await self.load_metadata()
+
+ resp = await self._http_client.get_json(
+ metadata["userinfo_endpoint"],
+ headers={"Authorization": ["Bearer {}".format(token["access_token"])]},
+ )
+
+ return UserInfo(resp)
+
+ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
+ """Return an instance of UserInfo from token's ``id_token``.
+
+ Args:
+ token: the token given by the ``token_endpoint``.
+ Must include an ``id_token`` field.
+ nonce: the nonce value originally sent in the initial authorization
+ request. This value should match the one inside the token.
+
+ Returns:
+ An object representing the user.
+ """
+ metadata = await self.load_metadata()
+ claims_params = {
+ "nonce": nonce,
+ "client_id": self._client_auth.client_id,
+ }
+ if "access_token" in token:
+ # If we got an `access_token`, there should be an `at_hash` claim
+ # in the `id_token` that we can check against.
+ claims_params["access_token"] = token["access_token"]
+ claims_cls = CodeIDToken
+ else:
+ claims_cls = ImplicitIDToken
+
+ alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
+
+ jwt = JsonWebToken(alg_values)
+
+ claim_options = {"iss": {"values": [metadata["issuer"]]}}
+
+ # Try to decode the keys in cache first, then retry by forcing the keys
+ # to be reloaded
+ jwk_set = await self.load_jwks()
+ try:
+ claims = jwt.decode(
+ token["id_token"],
+ key=jwk_set,
+ claims_cls=claims_cls,
+ claims_options=claim_options,
+ claims_params=claims_params,
+ )
+ except ValueError:
+ jwk_set = await self.load_jwks(force=True) # try reloading the jwks
+ claims = jwt.decode(
+ token["id_token"],
+ key=jwk_set,
+ claims_cls=claims_cls,
+ claims_options=claim_options,
+ claims_params=claims_params,
+ )
+
+ claims.validate(leeway=120) # allows 2 min of clock skew
+ return UserInfo(claims)
+
+ async def handle_redirect_request(
+ self, request: SynapseRequest, client_redirect_url: bytes
+ ) -> None:
+ """Handle an incoming request to /login/sso/redirect
+
+ It redirects the browser to the authorization endpoint with a few
+ parameters:
+
+ - ``client_id``: the client ID set in ``oidc_config.client_id``
+ - ``response_type``: ``code``
+ - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback``
+ - ``scope``: the list of scopes set in ``oidc_config.scopes``
+ - ``state``: a random string
+ - ``nonce``: a random string
+
+ In addition to redirecting the client, we are setting a cookie with
+ a signed macaroon token containing the state, the nonce and the
+ client_redirect_url params. Those are then checked when the client
+ comes back from the provider.
+
+
+ Args:
+ request: the incoming request from the browser.
+ We'll respond to it with a redirect and a cookie.
+ client_redirect_url: the URL that we should redirect the client to
+ when everything is done
+ """
+
+ state = generate_token()
+ nonce = generate_token()
+
+ cookie = self._generate_oidc_session_token(
+ state=state, nonce=nonce, client_redirect_url=client_redirect_url.decode(),
+ )
+ request.addCookie(
+ SESSION_COOKIE_NAME,
+ cookie,
+ path="/_synapse/oidc",
+ max_age="3600",
+ httpOnly=True,
+ sameSite="lax",
+ )
+
+ metadata = await self.load_metadata()
+ authorization_endpoint = metadata.get("authorization_endpoint")
+ uri = prepare_grant_uri(
+ authorization_endpoint,
+ client_id=self._client_auth.client_id,
+ response_type="code",
+ redirect_uri=self._callback_url,
+ scope=self._scopes,
+ state=state,
+ nonce=nonce,
+ )
+ request.redirect(uri)
+ finish_request(request)
+
+ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
+ """Handle an incoming request to /_synapse/oidc/callback
+
+ Since we might want to display OIDC-related errors in a user-friendly
+ way, we don't raise SynapseError from here. Instead, we call
+ ``self._render_error`` which displays an HTML page for the error.
+
+ Most of the OpenID Connect logic happens here:
+
+ - first, we check if there was any error returned by the provider and
+ display it
+ - then we fetch the session cookie, decode and verify it
+ - the ``state`` query parameter should match with the one stored in the
+ session cookie
+ - once we known this session is legit, exchange the code with the
+ provider using the ``token_endpoint`` (see ``_exchange_code``)
+ - once we have the token, use it to either extract the UserInfo from
+ the ``id_token`` (``_parse_id_token``), or use the ``access_token``
+ to fetch UserInfo from the ``userinfo_endpoint``
+ (``_fetch_userinfo``)
+ - map those UserInfo to a Matrix user (``_map_userinfo_to_user``) and
+ finish the login
+
+ Args:
+ request: the incoming request from the browser.
+ """
+
+ # The provider might redirect with an error.
+ # In that case, just display it as-is.
+ if b"error" in request.args:
+ error = request.args[b"error"][0].decode()
+ description = request.args.get(b"error_description", [b""])[0].decode()
+
+ # Most of the errors returned by the provider could be due by
+ # either the provider misbehaving or Synapse being misconfigured.
+ # The only exception of that is "access_denied", where the user
+ # probably cancelled the login flow. In other cases, log those errors.
+ if error != "access_denied":
+ logger.error("Error from the OIDC provider: %s %s", error, description)
+
+ self._render_error(request, error, description)
+ return
+
+ # Fetch the session cookie
+ session = request.getCookie(SESSION_COOKIE_NAME)
+ if session is None:
+ logger.info("No session cookie found")
+ self._render_error(request, "missing_session", "No session cookie found")
+ return
+
+ # Remove the cookie. There is a good chance that if the callback failed
+ # once, it will fail next time and the code will already be exchanged.
+ # Removing it early avoids spamming the provider with token requests.
+ request.addCookie(
+ SESSION_COOKIE_NAME,
+ b"",
+ path="/_synapse/oidc",
+ expires="Thu, Jan 01 1970 00:00:00 UTC",
+ httpOnly=True,
+ sameSite="lax",
+ )
+
+ # Check for the state query parameter
+ if b"state" not in request.args:
+ logger.info("State parameter is missing")
+ self._render_error(request, "invalid_request", "State parameter is missing")
+ return
+
+ state = request.args[b"state"][0].decode()
+
+ # Deserialize the session token and verify it.
+ try:
+ nonce, client_redirect_url = self._verify_oidc_session_token(session, state)
+ except MacaroonDeserializationException as e:
+ logger.exception("Invalid session")
+ self._render_error(request, "invalid_session", str(e))
+ return
+ except MacaroonInvalidSignatureException as e:
+ logger.exception("Could not verify session")
+ self._render_error(request, "mismatching_session", str(e))
+ return
+
+ # Exchange the code with the provider
+ if b"code" not in request.args:
+ logger.info("Code parameter is missing")
+ self._render_error(request, "invalid_request", "Code parameter is missing")
+ return
+
+ logger.info("Exchanging code")
+ code = request.args[b"code"][0].decode()
+ try:
+ token = await self._exchange_code(code)
+ except OidcError as e:
+ logger.exception("Could not exchange code")
+ self._render_error(request, e.error, e.error_description)
+ return
+
+ # Now that we have a token, get the userinfo, either by decoding the
+ # `id_token` or by fetching the `userinfo_endpoint`.
+ if self._uses_userinfo:
+ logger.info("Fetching userinfo")
+ try:
+ userinfo = await self._fetch_userinfo(token)
+ except Exception as e:
+ logger.exception("Could not fetch userinfo")
+ self._render_error(request, "fetch_error", str(e))
+ return
+ else:
+ logger.info("Extracting userinfo from id_token")
+ try:
+ userinfo = await self._parse_id_token(token, nonce=nonce)
+ except Exception as e:
+ logger.exception("Invalid id_token")
+ self._render_error(request, "invalid_token", str(e))
+ return
+
+ # Call the mapper to register/login the user
+ try:
+ user_id = await self._map_userinfo_to_user(userinfo, token)
+ except MappingException as e:
+ logger.exception("Could not map user")
+ self._render_error(request, "mapping_error", str(e))
+ return
+
+ # and finally complete the login
+ await self._auth_handler.complete_sso_login(
+ user_id, request, client_redirect_url
+ )
+
+ def _generate_oidc_session_token(
+ self,
+ state: str,
+ nonce: str,
+ client_redirect_url: str,
+ duration_in_ms: int = (60 * 60 * 1000),
+ ) -> str:
+ """Generates a signed token storing data about an OIDC session.
+
+ When Synapse initiates an authorization flow, it creates a random state
+ and a random nonce. Those parameters are given to the provider and
+ should be verified when the client comes back from the provider.
+ It is also used to store the client_redirect_url, which is used to
+ complete the SSO login flow.
+
+ Args:
+ state: The ``state`` parameter passed to the OIDC provider.
+ nonce: The ``nonce`` parameter passed to the OIDC provider.
+ client_redirect_url: The URL the client gave when it initiated the
+ flow.
+ duration_in_ms: An optional duration for the token in milliseconds.
+ Defaults to an hour.
+
+ Returns:
+ A signed macaroon token with the session informations.
+ """
+ macaroon = pymacaroons.Macaroon(
+ location=self._server_name, identifier="key", key=self._macaroon_secret_key,
+ )
+ macaroon.add_first_party_caveat("gen = 1")
+ macaroon.add_first_party_caveat("type = session")
+ macaroon.add_first_party_caveat("state = %s" % (state,))
+ macaroon.add_first_party_caveat("nonce = %s" % (nonce,))
+ macaroon.add_first_party_caveat(
+ "client_redirect_url = %s" % (client_redirect_url,)
+ )
+ now = self._clock.time_msec()
+ expiry = now + duration_in_ms
+ macaroon.add_first_party_caveat("time < %d" % (expiry,))
+ return macaroon.serialize()
+
+ def _verify_oidc_session_token(self, session: str, state: str) -> Tuple[str, str]:
+ """Verifies and extract an OIDC session token.
+
+ This verifies that a given session token was issued by this homeserver
+ and extract the nonce and client_redirect_url caveats.
+
+ Args:
+ session: The session token to verify
+ state: The state the OIDC provider gave back
+
+ Returns:
+ The nonce and the client_redirect_url for this session
+ """
+ macaroon = pymacaroons.Macaroon.deserialize(session)
+
+ v = pymacaroons.Verifier()
+ v.satisfy_exact("gen = 1")
+ v.satisfy_exact("type = session")
+ v.satisfy_exact("state = %s" % (state,))
+ v.satisfy_general(lambda c: c.startswith("nonce = "))
+ v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
+ v.satisfy_general(self._verify_expiry)
+
+ v.verify(macaroon, self._macaroon_secret_key)
+
+ # Extract the `nonce` and `client_redirect_url` from the token
+ nonce = self._get_value_from_macaroon(macaroon, "nonce")
+ client_redirect_url = self._get_value_from_macaroon(
+ macaroon, "client_redirect_url"
+ )
+
+ return nonce, client_redirect_url
+
+ def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
+ """Extracts a caveat value from a macaroon token.
+
+ Args:
+ macaroon: the token
+ key: the key of the caveat to extract
+
+ Returns:
+ The extracted value
+
+ Raises:
+ Exception: if the caveat was not in the macaroon
+ """
+ prefix = key + " = "
+ for caveat in macaroon.caveats:
+ if caveat.caveat_id.startswith(prefix):
+ return caveat.caveat_id[len(prefix) :]
+ raise Exception("No %s caveat in macaroon" % (key,))
+
+ def _verify_expiry(self, caveat: str) -> bool:
+ prefix = "time < "
+ if not caveat.startswith(prefix):
+ return False
+ expiry = int(caveat[len(prefix) :])
+ now = self._clock.time_msec()
+ return now < expiry
+
+ async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str:
+ """Maps a UserInfo object to a mxid.
+
+ UserInfo should have a claim that uniquely identifies users. This claim
+ is usually `sub`, but can be configured with `oidc_config.subject_claim`.
+ It is then used as an `external_id`.
+
+ If we don't find the user that way, we should register the user,
+ mapping the localpart and the display name from the UserInfo.
+
+ If a user already exists with the mxid we've mapped, raise an exception.
+
+ Args:
+ userinfo: an object representing the user
+ token: a dict with the tokens obtained from the provider
+
+ Raises:
+ MappingException: if there was an error while mapping some properties
+
+ Returns:
+ The mxid of the user
+ """
+ try:
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ except Exception as e:
+ raise MappingException(
+ "Failed to extract subject from OIDC response: %s" % (e,)
+ )
+
+ logger.info(
+ "Looking for existing mapping for user %s:%s",
+ self._auth_provider_id,
+ remote_user_id,
+ )
+
+ registered_user_id = await self._datastore.get_user_by_external_id(
+ self._auth_provider_id, remote_user_id,
+ )
+
+ if registered_user_id is not None:
+ logger.info("Found existing mapping %s", registered_user_id)
+ return registered_user_id
+
+ try:
+ attributes = await self._user_mapping_provider.map_user_attributes(
+ userinfo, token
+ )
+ except Exception as e:
+ raise MappingException(
+ "Could not extract user attributes from OIDC response: " + str(e)
+ )
+
+ logger.debug(
+ "Retrieved user attributes from user mapping provider: %r", attributes
+ )
+
+ if not attributes["localpart"]:
+ raise MappingException("localpart is empty")
+
+ localpart = map_username_to_mxid_localpart(attributes["localpart"])
+
+ user_id = UserID(localpart, self._hostname)
+ if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()):
+ # This mxid is taken
+ raise MappingException(
+ "mxid '{}' is already taken".format(user_id.to_string())
+ )
+
+ # It's the first time this user is logging in and the mapped mxid was
+ # not taken, register the user
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart, default_display_name=attributes["display_name"],
+ )
+
+ await self._datastore.record_user_external_id(
+ self._auth_provider_id, remote_user_id, registered_user_id,
+ )
+ return registered_user_id
+
+
+UserAttribute = TypedDict(
+ "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+)
+C = TypeVar("C")
+
+
+class OidcMappingProvider(Generic[C]):
+ """A mapping provider maps a UserInfo object to user attributes.
+
+ It should provide the API described by this class.
+ """
+
+ def __init__(self, config: C):
+ """
+ Args:
+ config: A custom config object from this module, parsed by ``parse_config()``
+ """
+
+ @staticmethod
+ def parse_config(config: dict) -> C:
+ """Parse the dict provided by the homeserver's config
+
+ Args:
+ config: A dictionary containing configuration options for this provider
+
+ Returns:
+ A custom config object for this module
+ """
+ raise NotImplementedError()
+
+ def get_remote_user_id(self, userinfo: UserInfo) -> str:
+ """Get a unique user ID for this user.
+
+ Usually, in an OIDC-compliant scenario, it should be the ``sub`` claim from the UserInfo object.
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+
+ Returns:
+ A unique user ID
+ """
+ raise NotImplementedError()
+
+ async def map_user_attributes(
+ self, userinfo: UserInfo, token: Token
+ ) -> UserAttribute:
+ """Map a ``UserInfo`` objects into user attributes.
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+ token: A dict with the tokens returned by the provider
+
+ Returns:
+ A dict containing the ``localpart`` and (optionally) the ``display_name``
+ """
+ raise NotImplementedError()
+
+
+# Used to clear out "None" values in templates
+def jinja_finalize(thing):
+ return thing if thing is not None else ""
+
+
+env = Environment(finalize=jinja_finalize)
+
+
+@attr.s
+class JinjaOidcMappingConfig:
+ subject_claim = attr.ib() # type: str
+ localpart_template = attr.ib() # type: Template
+ display_name_template = attr.ib() # type: Optional[Template]
+
+
+class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
+ """An implementation of a mapping provider based on Jinja templates.
+
+ This is the default mapping provider.
+ """
+
+ def __init__(self, config: JinjaOidcMappingConfig):
+ self._config = config
+
+ @staticmethod
+ def parse_config(config: dict) -> JinjaOidcMappingConfig:
+ subject_claim = config.get("subject_claim", "sub")
+
+ if "localpart_template" not in config:
+ raise ConfigError(
+ "missing key: oidc_config.user_mapping_provider.config.localpart_template"
+ )
+
+ try:
+ localpart_template = env.from_string(config["localpart_template"])
+ except Exception as e:
+ raise ConfigError(
+ "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r"
+ % (e,)
+ )
+
+ display_name_template = None # type: Optional[Template]
+ if "display_name_template" in config:
+ try:
+ display_name_template = env.from_string(config["display_name_template"])
+ except Exception as e:
+ raise ConfigError(
+ "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r"
+ % (e,)
+ )
+
+ return JinjaOidcMappingConfig(
+ subject_claim=subject_claim,
+ localpart_template=localpart_template,
+ display_name_template=display_name_template,
+ )
+
+ def get_remote_user_id(self, userinfo: UserInfo) -> str:
+ return userinfo[self._config.subject_claim]
+
+ async def map_user_attributes(
+ self, userinfo: UserInfo, token: Token
+ ) -> UserAttribute:
+ localpart = self._config.localpart_template.render(user=userinfo).strip()
+
+ display_name = None # type: Optional[str]
+ if self._config.display_name_template is not None:
+ display_name = self._config.display_name_template.render(
+ user=userinfo
+ ).strip()
+
+ if display_name == "":
+ display_name = None
+
+ return UserAttribute(localpart=localpart, display_name=display_name)
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
new file mode 100644
index 0000000000..d06b110269
--- /dev/null
+++ b/synapse/handlers/password_policy.py
@@ -0,0 +1,93 @@
+# -*- coding: utf-8 -*-
+# Copyright 2019 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import re
+
+from synapse.api.errors import Codes, PasswordRefusedError
+
+logger = logging.getLogger(__name__)
+
+
+class PasswordPolicyHandler(object):
+ def __init__(self, hs):
+ self.policy = hs.config.password_policy
+ self.enabled = hs.config.password_policy_enabled
+
+ # Regexps for the spec'd policy parameters.
+ self.regexp_digit = re.compile("[0-9]")
+ self.regexp_symbol = re.compile("[^a-zA-Z0-9]")
+ self.regexp_uppercase = re.compile("[A-Z]")
+ self.regexp_lowercase = re.compile("[a-z]")
+
+ def validate_password(self, password):
+ """Checks whether a given password complies with the server's policy.
+
+ Args:
+ password (str): The password to check against the server's policy.
+
+ Raises:
+ PasswordRefusedError: The password doesn't comply with the server's policy.
+ """
+
+ if not self.enabled:
+ return
+
+ minimum_accepted_length = self.policy.get("minimum_length", 0)
+ if len(password) < minimum_accepted_length:
+ raise PasswordRefusedError(
+ msg=(
+ "The password must be at least %d characters long"
+ % minimum_accepted_length
+ ),
+ errcode=Codes.PASSWORD_TOO_SHORT,
+ )
+
+ if (
+ self.policy.get("require_digit", False)
+ and self.regexp_digit.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one digit",
+ errcode=Codes.PASSWORD_NO_DIGIT,
+ )
+
+ if (
+ self.policy.get("require_symbol", False)
+ and self.regexp_symbol.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one symbol",
+ errcode=Codes.PASSWORD_NO_SYMBOL,
+ )
+
+ if (
+ self.policy.get("require_uppercase", False)
+ and self.regexp_uppercase.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one uppercase letter",
+ errcode=Codes.PASSWORD_NO_UPPERCASE,
+ )
+
+ if (
+ self.policy.get("require_lowercase", False)
+ and self.regexp_lowercase.search(password) is None
+ ):
+ raise PasswordRefusedError(
+ msg="The password must include at least one lowercase letter",
+ errcode=Codes.PASSWORD_NO_LOWERCASE,
+ )
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 5526015ddb..5cbefae177 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -21,10 +22,10 @@ The methods that define policy are:
- PresenceHandler._handle_timeouts
- should_notify
"""
-
+import abc
import logging
from contextlib import contextmanager
-from typing import Dict, List, Set
+from typing import Dict, Iterable, List, Set
from six import iteritems, itervalues
@@ -41,7 +42,7 @@ from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.presence import UserPresenceState
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
@@ -99,13 +100,106 @@ EXTERNAL_PROCESS_EXPIRY = 5 * 60 * 1000
assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
-class PresenceHandler(object):
+class BasePresenceHandler(abc.ABC):
+ """Parts of the PresenceHandler that are shared between workers and master"""
+
+ def __init__(self, hs: "synapse.server.HomeServer"):
+ self.clock = hs.get_clock()
+ self.store = hs.get_datastore()
+
+ active_presence = self.store.take_presence_startup_info()
+ self.user_to_current_state = {state.user_id: state for state in active_presence}
+
+ @abc.abstractmethod
+ async def user_syncing(
+ self, user_id: str, affect_presence: bool
+ ) -> ContextManager[None]:
+ """Returns a context manager that should surround any stream requests
+ from the user.
+
+ This allows us to keep track of who is currently streaming and who isn't
+ without having to have timers outside of this module to avoid flickering
+ when users disconnect/reconnect.
+
+ Args:
+ user_id: the user that is starting a sync
+ affect_presence: If false this function will be a no-op.
+ Useful for streams that are not associated with an actual
+ client that is being used by a user.
+ """
+
+ @abc.abstractmethod
+ def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
+ """Get an iterable of syncing users on this worker, to send to the presence handler
+
+ This is called when a replication connection is established. It should return
+ a list of user ids, which are then sent as USER_SYNC commands to inform the
+ process handling presence about those users.
+
+ Returns:
+ An iterable of user_id strings.
+ """
+
+ async def get_state(self, target_user: UserID) -> UserPresenceState:
+ results = await self.get_states([target_user.to_string()])
+ return results[0]
+
+ async def get_states(
+ self, target_user_ids: Iterable[str]
+ ) -> List[UserPresenceState]:
+ """Get the presence state for users."""
+
+ updates_d = await self.current_state_for_users(target_user_ids)
+ updates = list(updates_d.values())
+
+ for user_id in set(target_user_ids) - {u.user_id for u in updates}:
+ updates.append(UserPresenceState.default(user_id))
+
+ return updates
+
+ async def current_state_for_users(
+ self, user_ids: Iterable[str]
+ ) -> Dict[str, UserPresenceState]:
+ """Get the current presence state for multiple users.
+
+ Returns:
+ dict: `user_id` -> `UserPresenceState`
+ """
+ states = {
+ user_id: self.user_to_current_state.get(user_id, None)
+ for user_id in user_ids
+ }
+
+ missing = [user_id for user_id, state in iteritems(states) if not state]
+ if missing:
+ # There are things not in our in memory cache. Lets pull them out of
+ # the database.
+ res = await self.store.get_presence_for_users(missing)
+ states.update(res)
+
+ missing = [user_id for user_id, state in iteritems(states) if not state]
+ if missing:
+ new = {
+ user_id: UserPresenceState.default(user_id) for user_id in missing
+ }
+ states.update(new)
+ self.user_to_current_state.update(new)
+
+ return states
+
+ @abc.abstractmethod
+ async def set_state(
+ self, target_user: UserID, state: JsonDict, ignore_status_msg: bool = False
+ ) -> None:
+ """Set the presence state of the user. """
+
+
+class PresenceHandler(BasePresenceHandler):
def __init__(self, hs: "synapse.server.HomeServer"):
+ super().__init__(hs)
self.hs = hs
self.is_mine_id = hs.is_mine_id
self.server_name = hs.hostname
- self.clock = hs.get_clock()
- self.store = hs.get_datastore()
self.wheel_timer = WheelTimer()
self.notifier = hs.get_notifier()
self.federation = hs.get_federation_sender()
@@ -115,13 +209,6 @@ class PresenceHandler(object):
federation_registry.register_edu_handler("m.presence", self.incoming_presence)
- active_presence = self.store.take_presence_startup_info()
-
- # A dictionary of the current state of users. This is prefilled with
- # non-offline presence from the DB. We should fetch from the DB if
- # we can't find a users presence in here.
- self.user_to_current_state = {state.user_id: state for state in active_presence}
-
LaterGauge(
"synapse_handlers_presence_user_to_current_state_size",
"",
@@ -130,7 +217,7 @@ class PresenceHandler(object):
)
now = self.clock.time_msec()
- for state in active_presence:
+ for state in self.user_to_current_state.values():
self.wheel_timer.insert(
now=now, obj=state.user_id, then=state.last_active_ts + IDLE_TIMER
)
@@ -361,10 +448,18 @@ class PresenceHandler(object):
timers_fired_counter.inc(len(states))
+ syncing_user_ids = {
+ user_id
+ for user_id, count in self.user_to_num_current_syncs.items()
+ if count
+ }
+ for user_ids in self.external_process_to_current_syncs.values():
+ syncing_user_ids.update(user_ids)
+
changes = handle_timeouts(
states,
is_mine_fn=self.is_mine_id,
- syncing_user_ids=self.get_currently_syncing_users(),
+ syncing_user_ids=syncing_user_ids,
now=now,
)
@@ -462,22 +557,9 @@ class PresenceHandler(object):
return _user_syncing()
- def get_currently_syncing_users(self):
- """Get the set of user ids that are currently syncing on this HS.
- Returns:
- set(str): A set of user_id strings.
- """
- if self.hs.config.use_presence:
- syncing_user_ids = {
- user_id
- for user_id, count in self.user_to_num_current_syncs.items()
- if count
- }
- for user_ids in self.external_process_to_current_syncs.values():
- syncing_user_ids.update(user_ids)
- return syncing_user_ids
- else:
- return set()
+ def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
+ # since we are the process handling presence, there is nothing to do here.
+ return []
async def update_external_syncs_row(
self, process_id, user_id, is_syncing, sync_time_msec
@@ -554,34 +636,6 @@ class PresenceHandler(object):
res = await self.current_state_for_users([user_id])
return res[user_id]
- async def current_state_for_users(self, user_ids):
- """Get the current presence state for multiple users.
-
- Returns:
- dict: `user_id` -> `UserPresenceState`
- """
- states = {
- user_id: self.user_to_current_state.get(user_id, None)
- for user_id in user_ids
- }
-
- missing = [user_id for user_id, state in iteritems(states) if not state]
- if missing:
- # There are things not in our in memory cache. Lets pull them out of
- # the database.
- res = await self.store.get_presence_for_users(missing)
- states.update(res)
-
- missing = [user_id for user_id, state in iteritems(states) if not state]
- if missing:
- new = {
- user_id: UserPresenceState.default(user_id) for user_id in missing
- }
- states.update(new)
- self.user_to_current_state.update(new)
-
- return states
-
async def _persist_and_notify(self, states):
"""Persist states in the database, poke the notifier and send to
interested remote servers
@@ -669,40 +723,6 @@ class PresenceHandler(object):
federation_presence_counter.inc(len(updates))
await self._update_states(updates)
- async def get_state(self, target_user, as_event=False):
- results = await self.get_states([target_user.to_string()], as_event=as_event)
-
- return results[0]
-
- async def get_states(self, target_user_ids, as_event=False):
- """Get the presence state for users.
-
- Args:
- target_user_ids (list)
- as_event (bool): Whether to format it as a client event or not.
-
- Returns:
- list
- """
-
- updates = await self.current_state_for_users(target_user_ids)
- updates = list(updates.values())
-
- for user_id in set(target_user_ids) - {u.user_id for u in updates}:
- updates.append(UserPresenceState.default(user_id))
-
- now = self.clock.time_msec()
- if as_event:
- return [
- {
- "type": "m.presence",
- "content": format_user_presence_state(state, now),
- }
- for state in updates
- ]
- else:
- return updates
-
async def set_state(self, target_user, state, ignore_status_msg=False):
"""Set the presence state of the user.
"""
@@ -747,7 +767,7 @@ class PresenceHandler(object):
return False
- async def get_all_presence_updates(self, last_id, current_id):
+ async def get_all_presence_updates(self, last_id, current_id, limit):
"""
Gets a list of presence update rows from between the given stream ids.
Each row has:
@@ -762,7 +782,7 @@ class PresenceHandler(object):
"""
# TODO(markjh): replicate the unpersisted changes.
# This could use the in-memory stores for recent changes.
- rows = await self.store.get_all_presence_updates(last_id, current_id)
+ rows = await self.store.get_all_presence_updates(last_id, current_id, limit)
return rows
def notify_new_event(self):
@@ -889,7 +909,7 @@ class PresenceHandler(object):
user_ids = await self.state.get_current_users_in_room(room_id)
user_ids = list(filter(self.is_mine_id, user_ids))
- states = await self.current_state_for_users(user_ids)
+ states_d = await self.current_state_for_users(user_ids)
# Filter out old presence, i.e. offline presence states where
# the user hasn't been active for a week. We can change this
@@ -899,7 +919,7 @@ class PresenceHandler(object):
now = self.clock.time_msec()
states = [
state
- for state in states.values()
+ for state in states_d.values()
if state.state != PresenceState.OFFLINE
or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
or state.status_msg is not None
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 50ce0c585b..302efc1b9a 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -141,8 +141,9 @@ class BaseProfileHandler(BaseHandler):
return result["displayname"]
- @defer.inlineCallbacks
- def set_displayname(self, target_user, requester, new_displayname, by_admin=False):
+ async def set_displayname(
+ self, target_user, requester, new_displayname, by_admin=False
+ ):
"""Set the displayname of a user
Args:
@@ -157,6 +158,15 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's displayname")
+ if not by_admin and not self.hs.config.enable_set_displayname:
+ profile = await self.store.get_profileinfo(target_user.localpart)
+ if profile.display_name:
+ raise SynapseError(
+ 400,
+ "Changing display name is disabled on this server",
+ Codes.FORBIDDEN,
+ )
+
if len(new_displayname) > MAX_DISPLAYNAME_LEN:
raise SynapseError(
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
@@ -171,15 +181,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin:
requester = create_requester(target_user)
- yield self.store.set_profile_displayname(target_user.localpart, new_displayname)
+ await self.store.set_profile_displayname(target_user.localpart, new_displayname)
if self.hs.config.user_directory_search_all_users:
- profile = yield self.store.get_profileinfo(target_user.localpart)
- yield self.user_directory_handler.handle_local_profile_change(
+ profile = await self.store.get_profileinfo(target_user.localpart)
+ await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
- yield self._update_join_states(requester, target_user)
+ await self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
@@ -208,8 +218,9 @@ class BaseProfileHandler(BaseHandler):
return result["avatar_url"]
- @defer.inlineCallbacks
- def set_avatar_url(self, target_user, requester, new_avatar_url, by_admin=False):
+ async def set_avatar_url(
+ self, target_user, requester, new_avatar_url, by_admin=False
+ ):
"""target_user is the user whose avatar_url is to be changed;
auth_user is the user attempting to make this change."""
if not self.hs.is_mine(target_user):
@@ -218,6 +229,13 @@ class BaseProfileHandler(BaseHandler):
if not by_admin and target_user != requester.user:
raise AuthError(400, "Cannot set another user's avatar_url")
+ if not by_admin and not self.hs.config.enable_set_avatar_url:
+ profile = await self.store.get_profileinfo(target_user.localpart)
+ if profile.avatar_url:
+ raise SynapseError(
+ 400, "Changing avatar is disabled on this server", Codes.FORBIDDEN
+ )
+
if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
raise SynapseError(
400, "Avatar URL is too long (max %i)" % (MAX_AVATAR_URL_LEN,)
@@ -227,15 +245,15 @@ class BaseProfileHandler(BaseHandler):
if by_admin:
requester = create_requester(target_user)
- yield self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
+ await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
if self.hs.config.user_directory_search_all_users:
- profile = yield self.store.get_profileinfo(target_user.localpart)
- yield self.user_directory_handler.handle_local_profile_change(
+ profile = await self.store.get_profileinfo(target_user.localpart)
+ await self.user_directory_handler.handle_local_profile_change(
target_user.to_string(), profile
)
- yield self._update_join_states(requester, target_user)
+ await self._update_join_states(requester, target_user)
@defer.inlineCallbacks
def on_profile_query(self, args):
@@ -263,21 +281,20 @@ class BaseProfileHandler(BaseHandler):
return response
- @defer.inlineCallbacks
- def _update_join_states(self, requester, target_user):
+ async def _update_join_states(self, requester, target_user):
if not self.hs.is_mine(target_user):
return
- yield self.ratelimit(requester)
+ await self.ratelimit(requester)
- room_ids = yield self.store.get_rooms_for_user(target_user.to_string())
+ room_ids = await self.store.get_rooms_for_user(target_user.to_string())
for room_id in room_ids:
handler = self.hs.get_room_member_handler()
try:
# Assume the target_user isn't a guest,
# because we don't let guests set profile or avatar data.
- yield handler.update_membership(
+ await handler.update_membership(
requester,
target_user,
room_id,
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 7ffc194f0c..1e6bdac0ad 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -145,9 +145,9 @@ class RegistrationHandler(BaseHandler):
"""Registers a new client on the server.
Args:
- localpart : The local part of the user ID to register. If None,
+ localpart: The local part of the user ID to register. If None,
one will be generated.
- password (unicode) : The password to assign to this user so they can
+ password (unicode): The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
user_type (str|None): type of user. One of the values from
@@ -166,7 +166,9 @@ class RegistrationHandler(BaseHandler):
yield self.auth.check_auth_blocking(threepid=threepid)
password_hash = None
if password:
- password_hash = yield self._auth_handler.hash(password)
+ password_hash = yield defer.ensureDeferred(
+ self._auth_handler.hash(password)
+ )
if localpart is not None:
yield self.check_username(localpart, guest_access_token=guest_access_token)
@@ -242,7 +244,7 @@ class RegistrationHandler(BaseHandler):
fail_count += 1
if not self.hs.config.user_consent_at_registration:
- yield self._auto_join_rooms(user_id)
+ yield defer.ensureDeferred(self._auto_join_rooms(user_id))
else:
logger.info(
"Skipping auto-join for %s because consent is required at registration",
@@ -264,8 +266,7 @@ class RegistrationHandler(BaseHandler):
return user_id
- @defer.inlineCallbacks
- def _auto_join_rooms(self, user_id):
+ async def _auto_join_rooms(self, user_id):
"""Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created.
@@ -279,9 +280,9 @@ class RegistrationHandler(BaseHandler):
# that an auto-generated support or bot user is not a real user and will never be
# the user to create the room
should_auto_create_rooms = False
- is_real_user = yield self.store.is_real_user(user_id)
+ is_real_user = await self.store.is_real_user(user_id)
if self.hs.config.autocreate_auto_join_rooms and is_real_user:
- count = yield self.store.count_real_users()
+ count = await self.store.count_real_users()
should_auto_create_rooms = count == 1
for r in self.hs.config.auto_join_rooms:
logger.info("Auto-joining %s to %s", user_id, r)
@@ -300,7 +301,7 @@ class RegistrationHandler(BaseHandler):
# getting the RoomCreationHandler during init gives a dependency
# loop
- yield self.hs.get_room_creation_handler().create_room(
+ await self.hs.get_room_creation_handler().create_room(
fake_requester,
config={
"preset": "public_chat",
@@ -309,7 +310,7 @@ class RegistrationHandler(BaseHandler):
ratelimit=False,
)
else:
- yield self._join_user_to_room(fake_requester, r)
+ await self._join_user_to_room(fake_requester, r)
except ConsentNotGivenError as e:
# Technically not necessary to pull out this error though
# moving away from bare excepts is a good thing to do.
@@ -317,15 +318,14 @@ class RegistrationHandler(BaseHandler):
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
- @defer.inlineCallbacks
- def post_consent_actions(self, user_id):
+ async def post_consent_actions(self, user_id):
"""A series of registration actions that can only be carried out once consent
has been granted
Args:
user_id (str): The user to join
"""
- yield self._auto_join_rooms(user_id)
+ await self._auto_join_rooms(user_id)
@defer.inlineCallbacks
def appservice_register(self, user_localpart, as_token):
@@ -392,14 +392,13 @@ class RegistrationHandler(BaseHandler):
self._next_generated_user_id += 1
return str(id)
- @defer.inlineCallbacks
- def _join_user_to_room(self, requester, room_identifier):
+ async def _join_user_to_room(self, requester, room_identifier):
room_member_handler = self.hs.get_room_member_handler()
if RoomID.is_valid(room_identifier):
room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
- room_id, remote_room_hosts = yield room_member_handler.lookup_room_alias(
+ room_id, remote_room_hosts = await room_member_handler.lookup_room_alias(
room_alias
)
room_id = room_id.to_string()
@@ -408,7 +407,7 @@ class RegistrationHandler(BaseHandler):
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
- yield room_member_handler.update_membership(
+ await room_member_handler.update_membership(
requester=requester,
target=requester.user,
room_id=room_id,
@@ -540,14 +539,15 @@ class RegistrationHandler(BaseHandler):
user_id, ["guest = true"]
)
else:
- access_token = yield self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id, valid_until_ms=valid_until_ms
+ access_token = yield defer.ensureDeferred(
+ self._auth_handler.get_access_token_for_user_id(
+ user_id, device_id=device_id, valid_until_ms=valid_until_ms
+ )
)
return (device_id, access_token)
- @defer.inlineCallbacks
- def post_registration_actions(self, user_id, auth_result, access_token):
+ async def post_registration_actions(self, user_id, auth_result, access_token):
"""A user has completed registration
Args:
@@ -558,7 +558,7 @@ class RegistrationHandler(BaseHandler):
device, or None if `inhibit_login` enabled.
"""
if self.hs.config.worker_app:
- yield self._post_registration_client(
+ await self._post_registration_client(
user_id=user_id, auth_result=auth_result, access_token=access_token
)
return
@@ -570,19 +570,18 @@ class RegistrationHandler(BaseHandler):
if is_threepid_reserved(
self.hs.config.mau_limits_reserved_threepids, threepid
):
- yield self.store.upsert_monthly_active_user(user_id)
+ await self.store.upsert_monthly_active_user(user_id)
- yield self._register_email_threepid(user_id, threepid, access_token)
+ await self._register_email_threepid(user_id, threepid, access_token)
if auth_result and LoginType.MSISDN in auth_result:
threepid = auth_result[LoginType.MSISDN]
- yield self._register_msisdn_threepid(user_id, threepid)
+ await self._register_msisdn_threepid(user_id, threepid)
if auth_result and LoginType.TERMS in auth_result:
- yield self._on_user_consented(user_id, self.hs.config.user_consent_version)
+ await self._on_user_consented(user_id, self.hs.config.user_consent_version)
- @defer.inlineCallbacks
- def _on_user_consented(self, user_id, consent_version):
+ async def _on_user_consented(self, user_id, consent_version):
"""A user consented to the terms on registration
Args:
@@ -591,8 +590,8 @@ class RegistrationHandler(BaseHandler):
consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
- yield self.store.user_set_consent_version(user_id, consent_version)
- yield self.post_consent_actions(user_id)
+ await self.store.user_set_consent_version(user_id, consent_version)
+ await self.post_consent_actions(user_id)
@defer.inlineCallbacks
def _register_email_threepid(self, user_id, threepid, token):
@@ -617,8 +616,13 @@ class RegistrationHandler(BaseHandler):
logger.info("Can't add incomplete 3pid")
return
- yield self._auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
+ yield defer.ensureDeferred(
+ self._auth_handler.add_threepid(
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
+ )
)
# And we add an email pusher for them by default, but only
@@ -670,6 +674,11 @@ class RegistrationHandler(BaseHandler):
return None
raise
- yield self._auth_handler.add_threepid(
- user_id, threepid["medium"], threepid["address"], threepid["validated_at"]
+ yield defer.ensureDeferred(
+ self._auth_handler.add_threepid(
+ user_id,
+ threepid["medium"],
+ threepid["address"],
+ threepid["validated_at"],
+ )
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index f580ab2e9f..73f9eeb399 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -25,8 +25,6 @@ from collections import OrderedDict
from six import iteritems, string_types
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -103,8 +101,7 @@ class RoomCreationHandler(BaseHandler):
self.third_party_event_rules = hs.get_third_party_event_rules()
- @defer.inlineCallbacks
- def upgrade_room(
+ async def upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
"""Replace a room with a new room with a different version
@@ -117,7 +114,7 @@ class RoomCreationHandler(BaseHandler):
Returns:
Deferred[unicode]: the new room id
"""
- yield self.ratelimit(requester)
+ await self.ratelimit(requester)
user_id = requester.user.to_string()
@@ -138,7 +135,7 @@ class RoomCreationHandler(BaseHandler):
# If this user has sent multiple upgrade requests for the same room
# and one of them is not complete yet, cache the response and
# return it to all subsequent requests
- ret = yield self._upgrade_response_cache.wrap(
+ ret = await self._upgrade_response_cache.wrap(
(old_room_id, user_id),
self._upgrade_room,
requester,
@@ -148,17 +145,16 @@ class RoomCreationHandler(BaseHandler):
return ret
- @defer.inlineCallbacks
- def _upgrade_room(
+ async def _upgrade_room(
self, requester: Requester, old_room_id: str, new_version: RoomVersion
):
user_id = requester.user.to_string()
# start by allocating a new room id
- r = yield self.store.get_room(old_room_id)
+ r = await self.store.get_room(old_room_id)
if r is None:
raise NotFoundError("Unknown room id %s" % (old_room_id,))
- new_room_id = yield self._generate_room_id(
+ new_room_id = await self._generate_room_id(
creator_id=user_id, is_public=r["is_public"], room_version=new_version,
)
@@ -169,7 +165,7 @@ class RoomCreationHandler(BaseHandler):
(
tombstone_event,
tombstone_context,
- ) = yield self.event_creation_handler.create_event(
+ ) = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Tombstone,
@@ -183,12 +179,12 @@ class RoomCreationHandler(BaseHandler):
},
token_id=requester.access_token_id,
)
- old_room_version = yield self.store.get_room_version_id(old_room_id)
- yield self.auth.check_from_context(
+ old_room_version = await self.store.get_room_version_id(old_room_id)
+ await self.auth.check_from_context(
old_room_version, tombstone_event, tombstone_context
)
- yield self.clone_existing_room(
+ await self.clone_existing_room(
requester,
old_room_id=old_room_id,
new_room_id=new_room_id,
@@ -197,32 +193,31 @@ class RoomCreationHandler(BaseHandler):
)
# now send the tombstone
- yield self.event_creation_handler.send_nonmember_event(
+ await self.event_creation_handler.send_nonmember_event(
requester, tombstone_event, tombstone_context
)
- old_room_state = yield tombstone_context.get_current_state_ids()
+ old_room_state = await tombstone_context.get_current_state_ids()
# update any aliases
- yield self._move_aliases_to_new_room(
+ await self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state
)
# Copy over user push rules, tags and migrate room directory state
- yield self.room_member_handler.transfer_room_state_on_room_upgrade(
+ await self.room_member_handler.transfer_room_state_on_room_upgrade(
old_room_id, new_room_id
)
# finally, shut down the PLs in the old room, and update them in the new
# room.
- yield self._update_upgraded_room_pls(
+ await self._update_upgraded_room_pls(
requester, old_room_id, new_room_id, old_room_state,
)
return new_room_id
- @defer.inlineCallbacks
- def _update_upgraded_room_pls(
+ async def _update_upgraded_room_pls(
self,
requester: Requester,
old_room_id: str,
@@ -249,7 +244,7 @@ class RoomCreationHandler(BaseHandler):
)
return
- old_room_pl_state = yield self.store.get_event(old_room_pl_event_id)
+ old_room_pl_state = await self.store.get_event(old_room_pl_event_id)
# we try to stop regular users from speaking by setting the PL required
# to send regular events and invites to 'Moderator' level. That's normally
@@ -278,7 +273,7 @@ class RoomCreationHandler(BaseHandler):
if updated:
try:
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.PowerLevels,
@@ -292,7 +287,7 @@ class RoomCreationHandler(BaseHandler):
except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e)
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.PowerLevels,
@@ -304,8 +299,7 @@ class RoomCreationHandler(BaseHandler):
ratelimit=False,
)
- @defer.inlineCallbacks
- def clone_existing_room(
+ async def clone_existing_room(
self,
requester: Requester,
old_room_id: str,
@@ -338,7 +332,7 @@ class RoomCreationHandler(BaseHandler):
# Check if old room was non-federatable
# Get old room's create event
- old_room_create_event = yield self.store.get_create_event_for_room(old_room_id)
+ old_room_create_event = await self.store.get_create_event_for_room(old_room_id)
# Check if the create event specified a non-federatable room
if not old_room_create_event.content.get("m.federate", True):
@@ -361,11 +355,11 @@ class RoomCreationHandler(BaseHandler):
(EventTypes.PowerLevels, ""),
)
- old_room_state_ids = yield self.store.get_filtered_current_state_ids(
+ old_room_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types(types_to_copy)
)
# map from event_id to BaseEvent
- old_room_state_events = yield self.store.get_events(old_room_state_ids.values())
+ old_room_state_events = await self.store.get_events(old_room_state_ids.values())
for k, old_event_id in iteritems(old_room_state_ids):
old_event = old_room_state_events.get(old_event_id)
@@ -400,7 +394,7 @@ class RoomCreationHandler(BaseHandler):
if current_power_level < needed_power_level:
power_levels["users"][user_id] = needed_power_level
- yield self._send_events_for_new_room(
+ await self._send_events_for_new_room(
requester,
new_room_id,
# we expect to override all the presets with initial_state, so this is
@@ -412,12 +406,12 @@ class RoomCreationHandler(BaseHandler):
)
# Transfer membership events
- old_room_member_state_ids = yield self.store.get_filtered_current_state_ids(
+ old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
)
# map from event_id to BaseEvent
- old_room_member_state_events = yield self.store.get_events(
+ old_room_member_state_events = await self.store.get_events(
old_room_member_state_ids.values()
)
for k, old_event in iteritems(old_room_member_state_events):
@@ -426,7 +420,7 @@ class RoomCreationHandler(BaseHandler):
"membership" in old_event.content
and old_event.content["membership"] == "ban"
):
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester,
UserID.from_string(old_event["state_key"]),
new_room_id,
@@ -438,8 +432,7 @@ class RoomCreationHandler(BaseHandler):
# XXX invites/joins
# XXX 3pid invites
- @defer.inlineCallbacks
- def _move_aliases_to_new_room(
+ async def _move_aliases_to_new_room(
self,
requester: Requester,
old_room_id: str,
@@ -448,13 +441,13 @@ class RoomCreationHandler(BaseHandler):
):
directory_handler = self.hs.get_handlers().directory_handler
- aliases = yield self.store.get_aliases_for_room(old_room_id)
+ aliases = await self.store.get_aliases_for_room(old_room_id)
# check to see if we have a canonical alias.
canonical_alias_event = None
canonical_alias_event_id = old_room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event_id:
- canonical_alias_event = yield self.store.get_event(canonical_alias_event_id)
+ canonical_alias_event = await self.store.get_event(canonical_alias_event_id)
# first we try to remove the aliases from the old room (we suppress sending
# the room_aliases event until the end).
@@ -472,7 +465,7 @@ class RoomCreationHandler(BaseHandler):
for alias_str in aliases:
alias = RoomAlias.from_string(alias_str)
try:
- yield directory_handler.delete_association(requester, alias)
+ await directory_handler.delete_association(requester, alias)
removed_aliases.append(alias_str)
except SynapseError as e:
logger.warning("Unable to remove alias %s from old room: %s", alias, e)
@@ -485,7 +478,7 @@ class RoomCreationHandler(BaseHandler):
# we can now add any aliases we successfully removed to the new room.
for alias in removed_aliases:
try:
- yield directory_handler.create_association(
+ await directory_handler.create_association(
requester,
RoomAlias.from_string(alias),
new_room_id,
@@ -502,7 +495,7 @@ class RoomCreationHandler(BaseHandler):
# alias event for the new room with a copy of the information.
try:
if canonical_alias_event:
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.CanonicalAlias,
@@ -518,8 +511,9 @@ class RoomCreationHandler(BaseHandler):
# we returned the new room to the client at this point.
logger.error("Unable to send updated alias events in new room: %s", e)
- @defer.inlineCallbacks
- def create_room(self, requester, config, ratelimit=True, creator_join_profile=None):
+ async def create_room(
+ self, requester, config, ratelimit=True, creator_join_profile=None
+ ):
""" Creates a new room.
Args:
@@ -547,7 +541,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- yield self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(user_id)
if (
self._server_notices_mxid is not None
@@ -556,11 +550,11 @@ class RoomCreationHandler(BaseHandler):
# allow the server notices mxid to create rooms
is_requester_admin = True
else:
- is_requester_admin = yield self.auth.is_server_admin(requester.user)
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
# Check whether the third party rules allows/changes the room create
# request.
- event_allowed = yield self.third_party_event_rules.on_create_room(
+ event_allowed = await self.third_party_event_rules.on_create_room(
requester, config, is_requester_admin=is_requester_admin
)
if not event_allowed:
@@ -574,7 +568,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(403, "You are not permitted to create rooms")
if ratelimit:
- yield self.ratelimit(requester)
+ await self.ratelimit(requester)
room_version_id = config.get(
"room_version", self.config.default_room_version.identifier
@@ -597,7 +591,7 @@ class RoomCreationHandler(BaseHandler):
raise SynapseError(400, "Invalid characters in room alias")
room_alias = RoomAlias(config["room_alias_name"], self.hs.hostname)
- mapping = yield self.store.get_association_from_room_alias(room_alias)
+ mapping = await self.store.get_association_from_room_alias(room_alias)
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
@@ -612,7 +606,7 @@ class RoomCreationHandler(BaseHandler):
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))
- yield self.event_creation_handler.assert_accepted_privacy_policy(requester)
+ await self.event_creation_handler.assert_accepted_privacy_policy(requester)
power_level_content_override = config.get("power_level_content_override")
if (
@@ -631,13 +625,13 @@ class RoomCreationHandler(BaseHandler):
visibility = config.get("visibility", None)
is_public = visibility == "public"
- room_id = yield self._generate_room_id(
+ room_id = await self._generate_room_id(
creator_id=user_id, is_public=is_public, room_version=room_version,
)
directory_handler = self.hs.get_handlers().directory_handler
if room_alias:
- yield directory_handler.create_association(
+ await directory_handler.create_association(
requester=requester,
room_id=room_id,
room_alias=room_alias,
@@ -645,6 +639,13 @@ class RoomCreationHandler(BaseHandler):
check_membership=False,
)
+ if is_public:
+ if not self.config.is_publishing_room_allowed(user_id, room_id, room_alias):
+ # Lets just return a generic message, as there may be all sorts of
+ # reasons why we said no. TODO: Allow configurable error messages
+ # per alias creation rule?
+ raise SynapseError(403, "Not allowed to publish room")
+
preset_config = config.get(
"preset",
RoomCreationPreset.PRIVATE_CHAT
@@ -663,7 +664,7 @@ class RoomCreationHandler(BaseHandler):
# override any attempt to set room versions via the creation_content
creation_content["room_version"] = room_version.identifier
- yield self._send_events_for_new_room(
+ await self._send_events_for_new_room(
requester,
room_id,
preset_config=preset_config,
@@ -677,7 +678,7 @@ class RoomCreationHandler(BaseHandler):
if "name" in config:
name = config["name"]
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Name,
@@ -691,7 +692,7 @@ class RoomCreationHandler(BaseHandler):
if "topic" in config:
topic = config["topic"]
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.Topic,
@@ -709,7 +710,7 @@ class RoomCreationHandler(BaseHandler):
if is_direct:
content["is_direct"] = is_direct
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
@@ -723,7 +724,7 @@ class RoomCreationHandler(BaseHandler):
id_access_token = invite_3pid.get("id_access_token") # optional
address = invite_3pid["address"]
medium = invite_3pid["medium"]
- yield self.hs.get_room_member_handler().do_3pid_invite(
+ await self.hs.get_room_member_handler().do_3pid_invite(
room_id,
requester.user,
medium,
@@ -741,8 +742,7 @@ class RoomCreationHandler(BaseHandler):
return result
- @defer.inlineCallbacks
- def _send_events_for_new_room(
+ async def _send_events_for_new_room(
self,
creator, # A Requester object.
room_id,
@@ -762,11 +762,10 @@ class RoomCreationHandler(BaseHandler):
return e
- @defer.inlineCallbacks
- def send(etype, content, **kwargs):
+ async def send(etype, content, **kwargs):
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
creator, event, ratelimit=False
)
@@ -777,10 +776,10 @@ class RoomCreationHandler(BaseHandler):
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
creation_content.update({"creator": creator_id})
- yield send(etype=EventTypes.Create, content=creation_content)
+ await send(etype=EventTypes.Create, content=creation_content)
logger.debug("Sending %s in new room", EventTypes.Member)
- yield self.room_member_handler.update_membership(
+ await self.room_member_handler.update_membership(
creator,
creator.user,
room_id,
@@ -793,7 +792,7 @@ class RoomCreationHandler(BaseHandler):
# of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
- yield send(etype=EventTypes.PowerLevels, content=pl_content)
+ await send(etype=EventTypes.PowerLevels, content=pl_content)
else:
power_level_content = {
"users": {creator_id: 100},
@@ -806,6 +805,7 @@ class RoomCreationHandler(BaseHandler):
EventTypes.RoomAvatar: 50,
EventTypes.Tombstone: 100,
EventTypes.ServerACL: 100,
+ EventTypes.RoomEncryption: 100,
},
"events_default": 0,
"state_default": 50,
@@ -825,36 +825,35 @@ class RoomCreationHandler(BaseHandler):
if power_level_content_override:
power_level_content.update(power_level_content_override)
- yield send(etype=EventTypes.PowerLevels, content=power_level_content)
+ await send(etype=EventTypes.PowerLevels, content=power_level_content)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
- yield send(
+ await send(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
if (EventTypes.JoinRules, "") not in initial_state:
- yield send(
+ await send(
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
)
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
- yield send(
+ await send(
etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]},
)
if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state:
- yield send(
+ await send(
etype=EventTypes.GuestAccess, content={"guest_access": "can_join"}
)
for (etype, state_key), content in initial_state.items():
- yield send(etype=etype, state_key=state_key, content=content)
+ await send(etype=etype, state_key=state_key, content=content)
- @defer.inlineCallbacks
- def _generate_room_id(
+ async def _generate_room_id(
self, creator_id: str, is_public: str, room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
@@ -866,7 +865,7 @@ class RoomCreationHandler(BaseHandler):
gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
if isinstance(gen_room_id, bytes):
gen_room_id = gen_room_id.decode("utf-8")
- yield self.store.store_room(
+ await self.store.store_room(
room_id=gen_room_id,
room_creator_user_id=creator_id,
is_public=is_public,
@@ -885,8 +884,7 @@ class RoomContextHandler(object):
self.storage = hs.get_storage()
self.state_store = self.storage.state
- @defer.inlineCallbacks
- def get_event_context(self, user, room_id, event_id, limit, event_filter):
+ async def get_event_context(self, user, room_id, event_id, limit, event_filter):
"""Retrieves events, pagination tokens and state around a given event
in a room.
@@ -905,7 +903,7 @@ class RoomContextHandler(object):
before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit
- users = yield self.store.get_users_in_room(room_id)
+ users = await self.store.get_users_in_room(room_id)
is_peeking = user.to_string() not in users
def filter_evts(events):
@@ -913,17 +911,17 @@ class RoomContextHandler(object):
self.storage, user.to_string(), events, is_peeking=is_peeking
)
- event = yield self.store.get_event(
+ event = await self.store.get_event(
event_id, get_prev_content=True, allow_none=True
)
if not event:
return None
- filtered = yield (filter_evts([event]))
+ filtered = await filter_evts([event])
if not filtered:
raise AuthError(403, "You don't have permission to access that event.")
- results = yield self.store.get_events_around(
+ results = await self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
)
@@ -931,8 +929,8 @@ class RoomContextHandler(object):
results["events_before"] = event_filter.filter(results["events_before"])
results["events_after"] = event_filter.filter(results["events_after"])
- results["events_before"] = yield filter_evts(results["events_before"])
- results["events_after"] = yield filter_evts(results["events_after"])
+ results["events_before"] = await filter_evts(results["events_before"])
+ results["events_after"] = await filter_evts(results["events_after"])
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
@@ -959,7 +957,7 @@ class RoomContextHandler(object):
# first? Shouldn't we be consistent with /sync?
# https://github.com/matrix-org/matrix-doc/issues/687
- state = yield self.state_store.get_state_for_events(
+ state = await self.state_store.get_state_for_events(
[last_event_id], state_filter=state_filter
)
@@ -967,7 +965,7 @@ class RoomContextHandler(object):
if event_filter:
state_events = event_filter.filter(state_events)
- results["state"] = yield filter_evts(state_events)
+ results["state"] = await filter_evts(state_events)
# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
@@ -986,13 +984,12 @@ class RoomEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def get_new_events(
+ async def get_new_events(
self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
):
# We just ignore the key for now.
- to_key = yield self.get_current_key()
+ to_key = await self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
@@ -1005,11 +1002,11 @@ class RoomEventSource(object):
# See https://github.com/matrix-org/matrix-doc/issues/1144
raise NotImplementedError()
else:
- room_events = yield self.store.get_membership_changes_for_user(
+ room_events = await self.store.get_membership_changes_for_user(
user.to_string(), from_key, to_key
)
- room_to_events = yield self.store.get_room_events_stream_for_rooms(
+ room_to_events = await self.store.get_room_events_stream_for_rooms(
room_ids=room_ids,
from_key=from_key,
to_key=to_key,
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 0b7d3da680..e75dabcd77 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -15,6 +15,7 @@
import logging
from collections import namedtuple
+from typing import Any, Dict, Optional
from six import iteritems
@@ -89,7 +90,11 @@ class RoomListHandler(BaseHandler):
logger.info("Bypassing cache as search request.")
return self._get_public_room_list(
- limit, since_token, search_filter, network_tuple=network_tuple
+ limit,
+ since_token,
+ search_filter,
+ network_tuple=network_tuple,
+ from_federation=from_federation,
)
key = (limit, since_token, network_tuple)
@@ -105,22 +110,22 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def _get_public_room_list(
self,
- limit=None,
- since_token=None,
- search_filter=None,
- network_tuple=EMPTY_THIRD_PARTY_ID,
- from_federation=False,
- ):
+ limit: Optional[int] = None,
+ since_token: Optional[str] = None,
+ search_filter: Optional[Dict] = None,
+ network_tuple: ThirdPartyInstanceID = EMPTY_THIRD_PARTY_ID,
+ from_federation: bool = False,
+ ) -> Dict[str, Any]:
"""Generate a public room list.
Args:
- limit (int|None): Maximum amount of rooms to return.
- since_token (str|None)
- search_filter (dict|None): Dictionary to filter rooms by.
- network_tuple (ThirdPartyInstanceID): Which public list to use.
+ limit: Maximum amount of rooms to return.
+ since_token:
+ search_filter: Dictionary to filter rooms by.
+ network_tuple: Which public list to use.
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
- from_federation (bool): Whether this request originated from a
+ from_federation: Whether this request originated from a
federating server or a client. Used for room filtering.
"""
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4260426369..ccc9659454 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -142,8 +142,7 @@ class RoomMemberHandler(object):
"""
raise NotImplementedError()
- @defer.inlineCallbacks
- def _local_membership_update(
+ async def _local_membership_update(
self,
requester,
target,
@@ -164,7 +163,7 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
- event, context = yield self.event_creation_handler.create_event(
+ event, context = await self.event_creation_handler.create_event(
requester,
{
"type": EventTypes.Member,
@@ -182,18 +181,18 @@ class RoomMemberHandler(object):
)
# Check if this event matches the previous membership event for the user.
- duplicate = yield self.event_creation_handler.deduplicate_state_event(
+ duplicate = await self.event_creation_handler.deduplicate_state_event(
event, context
)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
return duplicate
- yield self.event_creation_handler.handle_new_client_event(
+ await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target], ratelimit=ratelimit
)
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids()
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
@@ -203,15 +202,15 @@ class RoomMemberHandler(object):
# info.
newly_joined = True
if prev_member_event_id:
- prev_member_event = yield self.store.get_event(prev_member_event_id)
+ prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- yield self._user_joined_room(target, room_id)
+ await self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
- prev_member_event = yield self.store.get_event(prev_member_event_id)
+ prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
- yield self._user_left_room(target, room_id)
+ await self._user_left_room(target, room_id)
return event
@@ -253,8 +252,7 @@ class RoomMemberHandler(object):
for tag, tag_content in room_tags.items():
yield self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content)
- @defer.inlineCallbacks
- def update_membership(
+ async def update_membership(
self,
requester,
target,
@@ -269,8 +267,8 @@ class RoomMemberHandler(object):
):
key = (room_id,)
- with (yield self.member_linearizer.queue(key)):
- result = yield self._update_membership(
+ with (await self.member_linearizer.queue(key)):
+ result = await self._update_membership(
requester,
target,
room_id,
@@ -285,8 +283,7 @@ class RoomMemberHandler(object):
return result
- @defer.inlineCallbacks
- def _update_membership(
+ async def _update_membership(
self,
requester,
target,
@@ -321,7 +318,7 @@ class RoomMemberHandler(object):
# if this is a join with a 3pid signature, we may need to turn a 3pid
# invite into a normal invite before we can handle the join.
if third_party_signed is not None:
- yield self.federation_handler.exchange_third_party_invite(
+ await self.federation_handler.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
@@ -332,7 +329,7 @@ class RoomMemberHandler(object):
remote_room_hosts = []
if effective_membership_state not in ("leave", "ban"):
- is_blocked = yield self.store.is_room_blocked(room_id)
+ is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
@@ -351,7 +348,7 @@ class RoomMemberHandler(object):
is_requester_admin = True
else:
- is_requester_admin = yield self.auth.is_server_admin(requester.user)
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
if self.config.block_non_admin_invites:
@@ -370,9 +367,9 @@ class RoomMemberHandler(object):
if block_invite:
raise SynapseError(403, "Invites have been disabled on this server")
- latest_event_ids = yield self.store.get_prev_events_for_room(room_id)
+ latest_event_ids = await self.store.get_prev_events_for_room(room_id)
- current_state_ids = yield self.state_handler.get_current_state_ids(
+ current_state_ids = await self.state_handler.get_current_state_ids(
room_id, latest_event_ids=latest_event_ids
)
@@ -381,7 +378,7 @@ class RoomMemberHandler(object):
# transitions and generic otherwise
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
if old_state_id:
- old_state = yield self.store.get_event(old_state_id, allow_none=True)
+ old_state = await self.store.get_event(old_state_id, allow_none=True)
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
@@ -413,7 +410,7 @@ class RoomMemberHandler(object):
old_membership == Membership.INVITE
and effective_membership_state == Membership.LEAVE
):
- is_blocked = yield self._is_server_notice_room(room_id)
+ is_blocked = await self._is_server_notice_room(room_id)
if is_blocked:
raise SynapseError(
http_client.FORBIDDEN,
@@ -424,18 +421,18 @@ class RoomMemberHandler(object):
if action == "kick":
raise AuthError(403, "The target user is not in the room")
- is_host_in_room = yield self._is_host_in_room(current_state_ids)
+ is_host_in_room = await self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN:
if requester.is_guest:
- guest_can_join = yield self._can_guest_join(current_state_ids)
+ guest_can_join = await self._can_guest_join(current_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
- inviter = yield self._get_inviter(target.to_string(), room_id)
+ inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
@@ -443,13 +440,13 @@ class RoomMemberHandler(object):
profile = self.profile_handler
if not content_specified:
- content["displayname"] = yield profile.get_displayname(target)
- content["avatar_url"] = yield profile.get_avatar_url(target)
+ content["displayname"] = await profile.get_displayname(target)
+ content["avatar_url"] = await profile.get_avatar_url(target)
if requester.is_guest:
content["kind"] = "guest"
- remote_join_response = yield self._remote_join(
+ remote_join_response = await self._remote_join(
requester, remote_room_hosts, room_id, target, content
)
@@ -458,7 +455,7 @@ class RoomMemberHandler(object):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
- inviter = yield self._get_inviter(target.to_string(), room_id)
+ inviter = await self._get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
@@ -472,12 +469,12 @@ class RoomMemberHandler(object):
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
- res = yield self._remote_reject_invite(
+ res = await self._remote_reject_invite(
requester, remote_room_hosts, room_id, target, content,
)
return res
- res = yield self._local_membership_update(
+ res = await self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
@@ -519,6 +516,9 @@ class RoomMemberHandler(object):
yield self.store.set_room_is_public(old_room_id, False)
yield self.store.set_room_is_public(room_id, True)
+ # Transfer alias mappings in the room directory
+ yield self.store.update_aliases_for_room(old_room_id, room_id)
+
# Check if any groups we own contain the predecessor room
local_group_ids = yield self.store.get_local_groups_for_room(old_room_id)
for group_id in local_group_ids:
@@ -569,8 +569,7 @@ class RoomMemberHandler(object):
)
continue
- @defer.inlineCallbacks
- def send_membership_event(self, requester, event, context, ratelimit=True):
+ async def send_membership_event(self, requester, event, context, ratelimit=True):
"""
Change the membership status of a user in a room.
@@ -596,27 +595,27 @@ class RoomMemberHandler(object):
else:
requester = types.create_requester(target_user)
- prev_event = yield self.event_creation_handler.deduplicate_state_event(
+ prev_event = await self.event_creation_handler.deduplicate_state_event(
event, context
)
if prev_event is not None:
return
- prev_state_ids = yield context.get_prev_state_ids()
+ prev_state_ids = await context.get_prev_state_ids()
if event.membership == Membership.JOIN:
if requester.is_guest:
- guest_can_join = yield self._can_guest_join(prev_state_ids)
+ guest_can_join = await self._can_guest_join(prev_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if event.membership not in (Membership.LEAVE, Membership.BAN):
- is_blocked = yield self.store.is_room_blocked(room_id)
+ is_blocked = await self.store.is_room_blocked(room_id)
if is_blocked:
raise SynapseError(403, "This room has been blocked on this server")
- yield self.event_creation_handler.handle_new_client_event(
+ await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target_user], ratelimit=ratelimit
)
@@ -630,15 +629,15 @@ class RoomMemberHandler(object):
# info.
newly_joined = True
if prev_member_event_id:
- prev_member_event = yield self.store.get_event(prev_member_event_id)
+ prev_member_event = await self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- yield self._user_joined_room(target_user, room_id)
+ await self._user_joined_room(target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
- prev_member_event = yield self.store.get_event(prev_member_event_id)
+ prev_member_event = await self.store.get_event(prev_member_event_id)
if prev_member_event.membership == Membership.JOIN:
- yield self._user_left_room(target_user, room_id)
+ await self._user_left_room(target_user, room_id)
@defer.inlineCallbacks
def _can_guest_join(self, current_state_ids):
@@ -696,8 +695,7 @@ class RoomMemberHandler(object):
if invite:
return UserID.from_string(invite.sender)
- @defer.inlineCallbacks
- def do_3pid_invite(
+ async def do_3pid_invite(
self,
room_id,
inviter,
@@ -709,7 +707,7 @@ class RoomMemberHandler(object):
id_access_token=None,
):
if self.config.block_non_admin_invites:
- is_requester_admin = yield self.auth.is_server_admin(requester.user)
+ is_requester_admin = await self.auth.is_server_admin(requester.user)
if not is_requester_admin:
raise SynapseError(
403, "Invites have been disabled on this server", Codes.FORBIDDEN
@@ -717,9 +715,9 @@ class RoomMemberHandler(object):
# We need to rate limit *before* we send out any 3PID invites, so we
# can't just rely on the standard ratelimiting of events.
- yield self.base_handler.ratelimit(requester)
+ await self.base_handler.ratelimit(requester)
- can_invite = yield self.third_party_event_rules.check_threepid_can_be_invited(
+ can_invite = await self.third_party_event_rules.check_threepid_can_be_invited(
medium, address, room_id
)
if not can_invite:
@@ -734,16 +732,16 @@ class RoomMemberHandler(object):
403, "Looking up third-party identifiers is denied from this server"
)
- invitee = yield self.identity_handler.lookup_3pid(
+ invitee = await self.identity_handler.lookup_3pid(
id_server, medium, address, id_access_token
)
if invitee:
- yield self.update_membership(
+ await self.update_membership(
requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
)
else:
- yield self._make_and_store_3pid_invite(
+ await self._make_and_store_3pid_invite(
requester,
id_server,
medium,
@@ -754,8 +752,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
- @defer.inlineCallbacks
- def _make_and_store_3pid_invite(
+ async def _make_and_store_3pid_invite(
self,
requester,
id_server,
@@ -766,7 +763,7 @@ class RoomMemberHandler(object):
txn_id,
id_access_token=None,
):
- room_state = yield self.state_handler.get_current_state(room_id)
+ room_state = await self.state_handler.get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
@@ -804,7 +801,7 @@ class RoomMemberHandler(object):
public_keys,
fallback_public_key,
display_name,
- ) = yield self.identity_handler.ask_id_server_for_third_party_invite(
+ ) = await self.identity_handler.ask_id_server_for_third_party_invite(
requester=requester,
id_server=id_server,
medium=medium,
@@ -820,7 +817,7 @@ class RoomMemberHandler(object):
id_access_token=id_access_token,
)
- yield self.event_creation_handler.create_and_send_nonmember_event(
+ await self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
@@ -878,8 +875,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
- @defer.inlineCallbacks
- def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
+ async def _is_remote_room_too_complex(self, room_id, remote_room_hosts):
"""
Check if complexity of a remote room is too great.
@@ -891,7 +887,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if unable to be fetched
"""
max_complexity = self.hs.config.limit_remote_rooms.complexity
- complexity = yield self.federation_handler.get_room_complexity(
+ complexity = await self.federation_handler.get_room_complexity(
remote_room_hosts, room_id
)
@@ -914,8 +910,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return complexity["v1"] > max_complexity
- @defer.inlineCallbacks
- def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
+ async def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
"""Implements RoomMemberHandler._remote_join
"""
# filter ourselves out of remote_room_hosts: do_invite_join ignores it
@@ -930,7 +925,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if self.hs.config.limit_remote_rooms.enabled:
# Fetch the room complexity
- too_complex = yield self._is_remote_room_too_complex(
+ too_complex = await self._is_remote_room_too_complex(
room_id, remote_room_hosts
)
if too_complex is True:
@@ -944,12 +939,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
- yield defer.ensureDeferred(
- self.federation_handler.do_invite_join(
- remote_room_hosts, room_id, user.to_string(), content
- )
+ await self.federation_handler.do_invite_join(
+ remote_room_hosts, room_id, user.to_string(), content
)
- yield self._user_joined_room(user, room_id)
+ await self._user_joined_room(user, room_id)
# Check the room we just joined wasn't too large, if we didn't fetch the
# complexity of it before.
@@ -959,7 +952,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return
# Check again, but with the local state events
- too_complex = yield self._is_local_room_too_complex(room_id)
+ too_complex = await self._is_local_room_too_complex(room_id)
if too_complex is False:
# We're under the limit.
@@ -967,7 +960,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# The room is too large. Leave.
requester = types.create_requester(user, None, False, None)
- yield self.update_membership(
+ await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
raise SynapseError(
@@ -1005,12 +998,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
def _user_joined_room(self, target, room_id):
"""Implements RoomMemberHandler._user_joined_room
"""
- return user_joined_room(self.distributor, target, room_id)
+ return defer.succeed(user_joined_room(self.distributor, target, room_id))
def _user_left_room(self, target, room_id):
"""Implements RoomMemberHandler._user_left_room
"""
- return user_left_room(self.distributor, target, room_id)
+ return defer.succeed(user_left_room(self.distributor, target, room_id))
@defer.inlineCallbacks
def forget(self, user, room_id):
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 72c109981b..e7015c704f 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import re
-from typing import Tuple
+from typing import Callable, Dict, Optional, Set, Tuple
import attr
import saml2
@@ -25,7 +25,9 @@ from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
+from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
+from synapse.module_api.errors import RedirectException
from synapse.types import (
UserID,
map_username_to_mxid_localpart,
@@ -43,11 +45,15 @@ class Saml2SessionData:
# time the session was created, in milliseconds
creation_time = attr.ib()
+ # The user interactive authentication session ID associated with this SAML
+ # session (or None if this SAML session is for an initial login).
+ ui_auth_session_id = attr.ib(type=Optional[str], default=None)
class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
+ self._auth = hs.get_auth()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
@@ -76,22 +82,28 @@ class SamlHandler:
self._error_html_content = hs.config.saml2_error_html_content
- def handle_redirect_request(self, client_redirect_url):
+ def handle_redirect_request(
+ self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
+ ) -> bytes:
"""Handle an incoming request to /login/sso/redirect
Args:
- client_redirect_url (bytes): the URL that we should redirect the
+ client_redirect_url: the URL that we should redirect the
client to when everything is done
+ ui_auth_session_id: The session ID of the ongoing UI Auth (or
+ None if this is a login).
Returns:
- bytes: URL to redirect to
+ URL to redirect to
"""
reqid, info = self._saml_client.prepare_for_authenticate(
relay_state=client_redirect_url
)
now = self._clock.time_msec()
- self._outstanding_requests_dict[reqid] = Saml2SessionData(creation_time=now)
+ self._outstanding_requests_dict[reqid] = Saml2SessionData(
+ creation_time=now, ui_auth_session_id=ui_auth_session_id,
+ )
for key, value in info["headers"]:
if key == "Location":
@@ -100,15 +112,15 @@ class SamlHandler:
# this shouldn't happen!
raise Exception("prepare_for_authenticate didn't return a Location header")
- async def handle_saml_response(self, request):
+ async def handle_saml_response(self, request: SynapseRequest) -> None:
"""Handle an incoming request to /_matrix/saml2/authn_response
Args:
- request (SynapseRequest): the incoming request from the browser. We'll
+ request: the incoming request from the browser. We'll
respond to it with a redirect.
Returns:
- Deferred[none]: Completes once we have handled the request.
+ Completes once we have handled the request.
"""
resp_bytes = parse_string(request, "SAMLResponse", required=True)
relay_state = parse_string(request, "RelayState", required=True)
@@ -118,7 +130,12 @@ class SamlHandler:
self.expire_sessions()
try:
- user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
+ user_id, current_session = await self._map_saml_response_to_user(
+ resp_bytes, relay_state
+ )
+ except RedirectException:
+ # Raise the exception as per the wishes of the SAML module response
+ raise
except Exception as e:
# If decoding the response or mapping it to a user failed, then log the
# error and tell the user that something went wrong.
@@ -133,9 +150,28 @@ class SamlHandler:
finish_request(request)
return
- self._auth_handler.complete_sso_login(user_id, request, relay_state)
+ # Complete the interactive auth session or the login.
+ if current_session and current_session.ui_auth_session_id:
+ await self._auth_handler.complete_sso_ui_auth(
+ user_id, current_session.ui_auth_session_id, request
+ )
+
+ else:
+ await self._auth_handler.complete_sso_login(user_id, request, relay_state)
+
+ async def _map_saml_response_to_user(
+ self, resp_bytes: str, client_redirect_url: str
+ ) -> Tuple[str, Optional[Saml2SessionData]]:
+ """
+ Given a sample response, retrieve the cached session and user for it.
- async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
+ Args:
+ resp_bytes: The SAML response.
+ client_redirect_url: The redirect URL passed in by the client.
+
+ Returns:
+ Tuple of the user ID and SAML session associated with this response.
+ """
try:
saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes,
@@ -163,7 +199,9 @@ class SamlHandler:
logger.info("SAML2 mapped attributes: %s", saml2_auth.ava)
- self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
+ current_session = self._outstanding_requests_dict.pop(
+ saml2_auth.in_response_to, None
+ )
remote_user_id = self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
@@ -184,7 +222,7 @@ class SamlHandler:
)
if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
- return registered_user_id
+ return registered_user_id, current_session
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
@@ -209,7 +247,7 @@ class SamlHandler:
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
- return registered_user_id
+ return registered_user_id, current_session
# Map saml response to user attributes using the configured mapping provider
for i in range(1000):
@@ -256,7 +294,7 @@ class SamlHandler:
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
- return registered_user_id
+ return registered_user_id, current_session
def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
@@ -275,6 +313,7 @@ DOT_REPLACE_PATTERN = re.compile(
def dot_replace_for_mxid(username: str) -> str:
+ """Replace any characters which are not allowed in Matrix IDs with a dot."""
username = username.lower()
username = DOT_REPLACE_PATTERN.sub(".", username)
@@ -286,7 +325,7 @@ def dot_replace_for_mxid(username: str) -> str:
MXID_MAPPER_MAP = {
"hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid,
-}
+} # type: Dict[str, Callable[[str], str]]
@attr.s
@@ -314,7 +353,7 @@ class DefaultSamlMappingProvider(object):
def get_remote_user_id(
self, saml_response: saml2.response.AuthnResponse, client_redirect_url: str
- ):
+ ) -> str:
"""Extracts the remote user id from the SAML response"""
try:
return saml_response.ava["uid"][0]
@@ -393,14 +432,14 @@ class DefaultSamlMappingProvider(object):
return SamlConfig(mxid_source_attribute, mxid_mapper)
@staticmethod
- def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
+ def get_saml_attributes(config: SamlConfig) -> Tuple[Set[str], Set[str]]:
"""Returns the required attributes of a SAML
Args:
config: A SamlConfig object containing configuration params for this provider
Returns:
- tuple[set,set]: The first set equates to the saml auth response
+ The first set equates to the saml auth response
attributes that are required for the module to function, whereas the
second set consists of those attributes which can be used if
available, but are not necessary
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index ec1542d416..4d40d3ac9c 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -18,8 +18,6 @@ import logging
from unpaddedbase64 import decode_base64, encode_base64
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
@@ -39,8 +37,7 @@ class SearchHandler(BaseHandler):
self.state_store = self.storage.state
self.auth = hs.get_auth()
- @defer.inlineCallbacks
- def get_old_rooms_from_upgraded_room(self, room_id):
+ async def get_old_rooms_from_upgraded_room(self, room_id):
"""Retrieves room IDs of old rooms in the history of an upgraded room.
We do so by checking the m.room.create event of the room for a
@@ -60,7 +57,7 @@ class SearchHandler(BaseHandler):
historical_room_ids = []
# The initial room must have been known for us to get this far
- predecessor = yield self.store.get_room_predecessor(room_id)
+ predecessor = await self.store.get_room_predecessor(room_id)
while True:
if not predecessor:
@@ -75,7 +72,7 @@ class SearchHandler(BaseHandler):
# Don't add it to the list until we have checked that we are in the room
try:
- next_predecessor_room = yield self.store.get_room_predecessor(
+ next_predecessor_room = await self.store.get_room_predecessor(
predecessor_room_id
)
except NotFoundError:
@@ -89,8 +86,7 @@ class SearchHandler(BaseHandler):
return historical_room_ids
- @defer.inlineCallbacks
- def search(self, user, content, batch=None):
+ async def search(self, user, content, batch=None):
"""Performs a full text search for a user.
Args:
@@ -179,7 +175,7 @@ class SearchHandler(BaseHandler):
search_filter = Filter(filter_dict)
# TODO: Search through left rooms too
- rooms = yield self.store.get_rooms_for_local_user_where_membership_is(
+ rooms = await self.store.get_rooms_for_local_user_where_membership_is(
user.to_string(),
membership_list=[Membership.JOIN],
# membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
@@ -192,7 +188,7 @@ class SearchHandler(BaseHandler):
historical_room_ids = []
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
- ids = yield self.get_old_rooms_from_upgraded_room(room_id)
+ ids = await self.get_old_rooms_from_upgraded_room(room_id)
historical_room_ids += ids
# Prevent any historical events from being filtered
@@ -223,7 +219,7 @@ class SearchHandler(BaseHandler):
count = None
if order_by == "rank":
- search_result = yield self.store.search_msgs(room_ids, search_term, keys)
+ search_result = await self.store.search_msgs(room_ids, search_term, keys)
count = search_result["count"]
@@ -238,7 +234,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
- events = yield filter_events_for_client(
+ events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)
@@ -267,7 +263,7 @@ class SearchHandler(BaseHandler):
# But only go around 5 times since otherwise synapse will be sad.
while len(room_events) < search_filter.limit() and i < 5:
i += 1
- search_result = yield self.store.search_rooms(
+ search_result = await self.store.search_rooms(
room_ids,
search_term,
keys,
@@ -288,7 +284,7 @@ class SearchHandler(BaseHandler):
filtered_events = search_filter.filter([r["event"] for r in results])
- events = yield filter_events_for_client(
+ events = await filter_events_for_client(
self.storage, user.to_string(), filtered_events
)
@@ -343,11 +339,11 @@ class SearchHandler(BaseHandler):
# If client has asked for "context" for each event (i.e. some surrounding
# events and state), fetch that
if event_context is not None:
- now_token = yield self.hs.get_event_sources().get_current_token()
+ now_token = await self.hs.get_event_sources().get_current_token()
contexts = {}
for event in allowed_events:
- res = yield self.store.get_events_around(
+ res = await self.store.get_events_around(
event.room_id, event.event_id, before_limit, after_limit
)
@@ -357,11 +353,11 @@ class SearchHandler(BaseHandler):
len(res["events_after"]),
)
- res["events_before"] = yield filter_events_for_client(
+ res["events_before"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_before"]
)
- res["events_after"] = yield filter_events_for_client(
+ res["events_after"] = await filter_events_for_client(
self.storage, user.to_string(), res["events_after"]
)
@@ -390,7 +386,7 @@ class SearchHandler(BaseHandler):
[(EventTypes.Member, sender) for sender in senders]
)
- state = yield self.state_store.get_state_for_event(
+ state = await self.state_store.get_state_for_event(
last_event_id, state_filter
)
@@ -412,10 +408,10 @@ class SearchHandler(BaseHandler):
time_now = self.clock.time_msec()
for context in contexts.values():
- context["events_before"] = yield self._event_serializer.serialize_events(
+ context["events_before"] = await self._event_serializer.serialize_events(
context["events_before"], time_now
)
- context["events_after"] = yield self._event_serializer.serialize_events(
+ context["events_after"] = await self._event_serializer.serialize_events(
context["events_after"], time_now
)
@@ -423,7 +419,7 @@ class SearchHandler(BaseHandler):
if include_state:
rooms = {e.room_id for e in allowed_events}
for room_id in rooms:
- state = yield self.state_handler.get_current_state(room_id)
+ state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
state_results.values()
@@ -437,7 +433,7 @@ class SearchHandler(BaseHandler):
{
"rank": rank_map[e.event_id],
"result": (
- yield self._event_serializer.serialize_event(e, time_now)
+ await self._event_serializer.serialize_event(e, time_now)
),
"context": contexts.get(e.event_id, {}),
}
@@ -452,7 +448,7 @@ class SearchHandler(BaseHandler):
if state_results:
s = {}
for room_id, state in state_results.items():
- s[room_id] = yield self._event_serializer.serialize_events(
+ s[room_id] = await self._event_serializer.serialize_events(
state, time_now
)
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 12657ca698..63d8f9aa0d 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -15,8 +15,6 @@
import logging
from typing import Optional
-from twisted.internet import defer
-
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester
@@ -32,9 +30,9 @@ class SetPasswordHandler(BaseHandler):
super(SetPasswordHandler, self).__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
+ self._password_policy_handler = hs.get_password_policy_handler()
- @defer.inlineCallbacks
- def set_password(
+ async def set_password(
self,
user_id: str,
new_password: str,
@@ -44,10 +42,11 @@ class SetPasswordHandler(BaseHandler):
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
- password_hash = yield self._auth_handler.hash(new_password)
+ self._password_policy_handler.validate_password(new_password)
+ password_hash = await self._auth_handler.hash(new_password)
try:
- yield self.store.user_set_password_hash(user_id, password_hash)
+ await self.store.user_set_password_hash(user_id, password_hash)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
@@ -59,12 +58,12 @@ class SetPasswordHandler(BaseHandler):
except_access_token_id = requester.access_token_id if requester else None
# First delete all of their other devices.
- yield self._device_handler.delete_all_devices_for_user(
+ await self._device_handler.delete_all_devices_for_user(
user_id, except_device_id=except_device_id
)
# and now delete any access tokens which weren't associated with
# devices (or were associated with this device).
- yield self._auth_handler.delete_access_tokens_for_user(
+ await self._auth_handler.delete_access_tokens_for_user(
user_id, except_token_id=except_access_token_id
)
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index cfd5dfc9e5..00718d7f2d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -26,7 +26,7 @@ from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership
from synapse.api.filtering import FilterCollection
from synapse.events import EventBase
-from synapse.logging.context import LoggingContext
+from synapse.logging.context import current_context
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
@@ -301,7 +301,7 @@ class SyncHandler(object):
else:
sync_type = "incremental_sync"
- context = LoggingContext.current_context()
+ context = current_context()
if context:
context.tag = sync_type
@@ -1143,10 +1143,14 @@ class SyncHandler(object):
user_id
)
- tracked_users = set(users_who_share_room)
+ # Always tell the user about their own devices. We check as the user
+ # ID is almost certainly already included (unless they're not in any
+ # rooms) and taking a copy of the set is relatively expensive.
+ if user_id not in users_who_share_room:
+ users_who_share_room = set(users_who_share_room)
+ users_who_share_room.add(user_id)
- # Always tell the user about their own devices
- tracked_users.add(user_id)
+ tracked_users = users_who_share_room
# Step 1a, check for changes in devices of users we share a room with
users_that_have_changed = await self.store.get_users_whose_devices_changed(
@@ -1639,7 +1643,7 @@ class SyncHandler(object):
)
# We loop through all room ids, even if there are no new events, in case
- # there are non room events taht we need to notify about.
+ # there are non room events that we need to notify about.
for room_id in sync_result_builder.joined_room_ids:
room_entry = room_to_events.get(room_id, None)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 391bceb0c4..c7bc14c623 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -15,6 +15,7 @@
import logging
from collections import namedtuple
+from typing import List
from twisted.internet import defer
@@ -257,7 +258,13 @@ class TypingHandler(object):
"typing_key", self._latest_room_serial, rooms=[member.room_id]
)
- async def get_all_typing_updates(self, last_id, current_id):
+ async def get_all_typing_updates(
+ self, last_id: int, current_id: int, limit: int
+ ) -> List[dict]:
+ """Get up to `limit` typing updates between the given tokens, earliest
+ updates first.
+ """
+
if last_id == current_id:
return []
@@ -275,7 +282,7 @@ class TypingHandler(object):
typing = self._room_typing[room_id]
rows.append((serial, room_id, list(typing)))
rows.sort()
- return rows
+ return rows[:limit]
def get_current_token(self):
return self._latest_room_serial
|