diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 829f52eca1..590135d19c 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -44,7 +44,11 @@ class AccountValidityHandler(object):
self._account_validity = self.hs.config.account_validity
- if self._account_validity.renew_by_email_enabled and load_jinja2_templates:
+ if (
+ self._account_validity.enabled
+ and self._account_validity.renew_by_email_enabled
+ and load_jinja2_templates
+ ):
# Don't do email-specific configuration if renewal by email is disabled.
try:
app_name = self.hs.config.email_app_name
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 48a88d3c2a..7860f9625e 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -17,9 +17,11 @@
import logging
import time
import unicodedata
+import urllib.parse
+from typing import Any, Dict, Iterable, List, Optional
import attr
-import bcrypt
+import bcrypt # type: ignore[import]
import pymacaroons
from twisted.internet import defer
@@ -38,9 +40,12 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http.server import finish_request
+from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.module_api import ModuleApi
-from synapse.types import UserID
+from synapse.push.mailer import load_jinja2_templates
+from synapse.types import Requester, UserID
from synapse.util.caches.expiringcache import ExpiringCache
from ._base import BaseHandler
@@ -58,11 +63,11 @@ class AuthHandler(BaseHandler):
"""
super(AuthHandler, self).__init__(hs)
- self.checkers = {} # type: dict[str, UserInteractiveAuthChecker]
+ self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker]
for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
inst = auth_checker_class(hs)
if inst.is_enabled():
- self.checkers[inst.AUTH_TYPE] = inst
+ self.checkers[inst.AUTH_TYPE] = inst # type: ignore
self.bcrypt_rounds = hs.config.bcrypt_rounds
@@ -108,8 +113,20 @@ class AuthHandler(BaseHandler):
self._clock = self.hs.get_clock()
+ # Load the SSO redirect confirmation page HTML template
+ self._sso_redirect_confirm_template = load_jinja2_templates(
+ hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
+ )[0]
+
+ self._server_name = hs.config.server_name
+
+ # cast to tuple for use with str.startswith
+ self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
+
@defer.inlineCallbacks
- def validate_user_via_ui_auth(self, requester, request_body, clientip):
+ def validate_user_via_ui_auth(
+ self, requester: Requester, request_body: Dict[str, Any], clientip: str
+ ):
"""
Checks that the user is who they claim to be, via a UI auth.
@@ -118,11 +135,11 @@ class AuthHandler(BaseHandler):
that it isn't stolen by re-authenticating them.
Args:
- requester (Requester): The user, as given by the access token
+ requester: The user, as given by the access token
- request_body (dict): The body of the request sent by the client
+ request_body: The body of the request sent by the client
- clientip (str): The IP address of the client.
+ clientip: The IP address of the client.
Returns:
defer.Deferred[dict]: the parameters for this request (which may
@@ -193,7 +210,9 @@ class AuthHandler(BaseHandler):
return self.checkers.keys()
@defer.inlineCallbacks
- def check_auth(self, flows, clientdict, clientip):
+ def check_auth(
+ self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
+ ):
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the User-Interactive Auth flow.
@@ -208,14 +227,14 @@ class AuthHandler(BaseHandler):
decorator.
Args:
- flows (list): A list of login flows. Each flow is an ordered list of
- strings representing auth-types. At least one full
- flow must be completed in order for auth to be successful.
+ flows: A list of login flows. Each flow is an ordered list of
+ strings representing auth-types. At least one full
+ flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
- clientip (str): The IP address of the client.
+ clientip: The IP address of the client.
Returns:
defer.Deferred[dict, dict, str]: a deferred tuple of
@@ -235,7 +254,7 @@ class AuthHandler(BaseHandler):
"""
authdict = None
- sid = None
+ sid = None # type: Optional[str]
if clientdict and "auth" in clientdict:
authdict = clientdict["auth"]
del clientdict["auth"]
@@ -268,9 +287,9 @@ class AuthHandler(BaseHandler):
creds = session["creds"]
# check auth type currently being presented
- errordict = {}
+ errordict = {} # type: Dict[str, Any]
if "type" in authdict:
- login_type = authdict["type"]
+ login_type = authdict["type"] # type: str
try:
result = yield self._check_auth_dict(authdict, clientip)
if result:
@@ -311,7 +330,7 @@ class AuthHandler(BaseHandler):
raise InteractiveAuthIncompleteError(ret)
@defer.inlineCallbacks
- def add_oob_auth(self, stagetype, authdict, clientip):
+ def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str):
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
@@ -333,7 +352,7 @@ class AuthHandler(BaseHandler):
return True
return False
- def get_session_id(self, clientdict):
+ def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
"""
Gets the session ID for a client given the client dictionary
@@ -341,7 +360,7 @@ class AuthHandler(BaseHandler):
clientdict: The dictionary sent by the client in the request
Returns:
- str|None: The string session ID the client sent. If the client did
+ The string session ID the client sent. If the client did
not send a session ID, returns None.
"""
sid = None
@@ -351,40 +370,42 @@ class AuthHandler(BaseHandler):
sid = authdict["session"]
return sid
- def set_session_data(self, session_id, key, value):
+ def set_session_data(self, session_id: str, key: str, value: Any) -> None:
"""
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
the client.
Args:
- session_id (string): The ID of this session as returned from check_auth
- key (string): The key to store the data under
- value (any): The data to store
+ session_id: The ID of this session as returned from check_auth
+ key: The key to store the data under
+ value: The data to store
"""
sess = self._get_session_info(session_id)
sess.setdefault("serverdict", {})[key] = value
self._save_session(sess)
- def get_session_data(self, session_id, key, default=None):
+ def get_session_data(
+ self, session_id: str, key: str, default: Optional[Any] = None
+ ) -> Any:
"""
Retrieve data stored with set_session_data
Args:
- session_id (string): The ID of this session as returned from check_auth
- key (string): The key to store the data under
- default (any): Value to return if the key has not been set
+ session_id: The ID of this session as returned from check_auth
+ key: The key to store the data under
+ default: Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess.setdefault("serverdict", {}).get(key, default)
@defer.inlineCallbacks
- def _check_auth_dict(self, authdict, clientip):
+ def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str):
"""Attempt to validate the auth dict provided by a client
Args:
- authdict (object): auth dict provided by the client
- clientip (str): IP address of the client
+ authdict: auth dict provided by the client
+ clientip: IP address of the client
Returns:
Deferred: result of the stage verification.
@@ -410,10 +431,10 @@ class AuthHandler(BaseHandler):
(canonical_id, callback) = yield self.validate_login(user_id, authdict)
return canonical_id
- def _get_params_recaptcha(self):
+ def _get_params_recaptcha(self) -> dict:
return {"public_key": self.hs.config.recaptcha_public_key}
- def _get_params_terms(self):
+ def _get_params_terms(self) -> dict:
return {
"policies": {
"privacy_policy": {
@@ -430,7 +451,9 @@ class AuthHandler(BaseHandler):
}
}
- def _auth_dict_for_flows(self, flows, session):
+ def _auth_dict_for_flows(
+ self, flows: List[List[str]], session: Dict[str, Any]
+ ) -> Dict[str, Any]:
public_flows = []
for f in flows:
public_flows.append(f)
@@ -440,7 +463,7 @@ class AuthHandler(BaseHandler):
LoginType.TERMS: self._get_params_terms,
}
- params = {}
+ params = {} # type: Dict[str, Any]
for f in public_flows:
for stage in f:
@@ -453,7 +476,13 @@ class AuthHandler(BaseHandler):
"params": params,
}
- def _get_session_info(self, session_id):
+ def _get_session_info(self, session_id: Optional[str]) -> dict:
+ """
+ Gets or creates a session given a session ID.
+
+ The session can be used to track data across multiple requests, e.g. for
+ interactive authentication.
+ """
if session_id not in self.sessions:
session_id = None
@@ -466,7 +495,9 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
@defer.inlineCallbacks
- def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms):
+ def get_access_token_for_user_id(
+ self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
+ ):
"""
Creates a new access token for the user with the given user ID.
@@ -476,11 +507,11 @@ class AuthHandler(BaseHandler):
The device will be recorded in the table if it is not there already.
Args:
- user_id (str): canonical User ID
- device_id (str|None): the device ID to associate with the tokens.
+ user_id: canonical User ID
+ device_id: the device ID to associate with the tokens.
None to leave the tokens unassociated with a device (deprecated:
we should always have a device ID)
- valid_until_ms (int|None): when the token is valid until. None for
+ valid_until_ms: when the token is valid until. None for
no expiry.
Returns:
The access token for the user's session.
@@ -515,13 +546,13 @@ class AuthHandler(BaseHandler):
return access_token
@defer.inlineCallbacks
- def check_user_exists(self, user_id):
+ def check_user_exists(self, user_id: str):
"""
Checks to see if a user with the given id exists. Will check case
insensitively, but return None if there are multiple inexact matches.
Args:
- (unicode|bytes) user_id: complete @user:id
+ user_id: complete @user:id
Returns:
defer.Deferred: (unicode) canonical_user_id, or None if zero or
@@ -536,7 +567,7 @@ class AuthHandler(BaseHandler):
return None
@defer.inlineCallbacks
- def _find_user_id_and_pwd_hash(self, user_id):
+ def _find_user_id_and_pwd_hash(self, user_id: str):
"""Checks to see if a user with the given id exists. Will check case
insensitively, but will return None if there are multiple inexact
matches.
@@ -566,7 +597,7 @@ class AuthHandler(BaseHandler):
)
return result
- def get_supported_login_types(self):
+ def get_supported_login_types(self) -> Iterable[str]:
"""Get a the login types supported for the /login API
By default this is just 'm.login.password' (unless password_enabled is
@@ -574,20 +605,20 @@ class AuthHandler(BaseHandler):
other login types.
Returns:
- Iterable[str]: login types
+ login types
"""
return self._supported_login_types
@defer.inlineCallbacks
- def validate_login(self, username, login_submission):
+ def validate_login(self, username: str, login_submission: Dict[str, Any]):
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate
m.login.password auth types.
Args:
- username (str): username supplied by the user
- login_submission (dict): the whole of the login submission
+ username: username supplied by the user
+ login_submission: the whole of the login submission
(including 'type' and other relevant fields)
Returns:
Deferred[str, func]: canonical user id, and optional callback
@@ -675,13 +706,13 @@ class AuthHandler(BaseHandler):
raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
- def check_password_provider_3pid(self, medium, address, password):
+ def check_password_provider_3pid(self, medium: str, address: str, password: str):
"""Check if a password provider is able to validate a thirdparty login
Args:
- medium (str): The medium of the 3pid (ex. email).
- address (str): The address of the 3pid (ex. jdoe@example.com).
- password (str): The password of the user.
+ medium: The medium of the 3pid (ex. email).
+ address: The address of the 3pid (ex. jdoe@example.com).
+ password: The password of the user.
Returns:
Deferred[(str|None, func|None)]: A tuple of `(user_id,
@@ -709,15 +740,15 @@ class AuthHandler(BaseHandler):
return None, None
@defer.inlineCallbacks
- def _check_local_password(self, user_id, password):
+ def _check_local_password(self, user_id: str, password: str):
"""Authenticate a user against the local password database.
user_id is checked case insensitively, but will return None if there are
multiple inexact matches.
Args:
- user_id (unicode): complete @user:id
- password (unicode): the provided password
+ user_id: complete @user:id
+ password: the provided password
Returns:
Deferred[unicode] the canonical_user_id, or Deferred[None] if
unknown user/bad password
@@ -740,7 +771,7 @@ class AuthHandler(BaseHandler):
return user_id
@defer.inlineCallbacks
- def validate_short_term_login_token_and_get_user_id(self, login_token):
+ def validate_short_term_login_token_and_get_user_id(self, login_token: str):
auth_api = self.hs.get_auth()
user_id = None
try:
@@ -754,11 +785,11 @@ class AuthHandler(BaseHandler):
return user_id
@defer.inlineCallbacks
- def delete_access_token(self, access_token):
+ def delete_access_token(self, access_token: str):
"""Invalidate a single access token
Args:
- access_token (str): access token to be deleted
+ access_token: access token to be deleted
Returns:
Deferred
@@ -783,15 +814,17 @@ class AuthHandler(BaseHandler):
@defer.inlineCallbacks
def delete_access_tokens_for_user(
- self, user_id, except_token_id=None, device_id=None
+ self,
+ user_id: str,
+ except_token_id: Optional[str] = None,
+ device_id: Optional[str] = None,
):
"""Invalidate access tokens belonging to a user
Args:
- user_id (str): ID of user the tokens belong to
- except_token_id (str|None): access_token ID which should *not* be
- deleted
- device_id (str|None): ID of device the tokens are associated with.
+ user_id: ID of user the tokens belong to
+ except_token_id: access_token ID which should *not* be deleted
+ device_id: ID of device the tokens are associated with.
If None, tokens associated with any device (or no device) will
be deleted
Returns:
@@ -815,7 +848,7 @@ class AuthHandler(BaseHandler):
)
@defer.inlineCallbacks
- def add_threepid(self, user_id, medium, address, validated_at):
+ def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int):
# check if medium has a valid value
if medium not in ["email", "msisdn"]:
raise SynapseError(
@@ -841,19 +874,20 @@ class AuthHandler(BaseHandler):
)
@defer.inlineCallbacks
- def delete_threepid(self, user_id, medium, address, id_server=None):
+ def delete_threepid(
+ self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
+ ):
"""Attempts to unbind the 3pid on the identity servers and deletes it
from the local database.
Args:
- user_id (str)
- medium (str)
- address (str)
- id_server (str|None): Use the given identity server when unbinding
+ user_id: ID of user to remove the 3pid from.
+ medium: The medium of the 3pid being removed: "email" or "msisdn".
+ address: The 3pid address to remove.
+ id_server: Use the given identity server when unbinding
any threepids. If None then will attempt to unbind using the
identity server specified when binding (if known).
-
Returns:
Deferred[bool]: Returns True if successfully unbound the 3pid on
the identity server, False if identity server doesn't support the
@@ -872,17 +906,18 @@ class AuthHandler(BaseHandler):
yield self.store.user_delete_threepid(user_id, medium, address)
return result
- def _save_session(self, session):
+ def _save_session(self, session: Dict[str, Any]) -> None:
+ """Update the last used time on the session to now and add it back to the session store."""
# TODO: Persistent storage
logger.debug("Saving session %s", session)
session["last_used"] = self.hs.get_clock().time_msec()
self.sessions[session["id"]] = session
- def hash(self, password):
+ def hash(self, password: str):
"""Computes a secure hash of password.
Args:
- password (unicode): Password to hash.
+ password: Password to hash.
Returns:
Deferred(unicode): Hashed password.
@@ -899,12 +934,12 @@ class AuthHandler(BaseHandler):
return defer_to_thread(self.hs.get_reactor(), _do_hash)
- def validate_hash(self, password, stored_hash):
+ def validate_hash(self, password: str, stored_hash: bytes):
"""Validates that self.hash(password) == stored_hash.
Args:
- password (unicode): Password to hash.
- stored_hash (bytes): Expected hash value.
+ password: Password to hash.
+ stored_hash: Expected hash value.
Returns:
Deferred(bool): Whether self.hash(password) == stored_hash.
@@ -927,13 +962,74 @@ class AuthHandler(BaseHandler):
else:
return defer.succeed(False)
+ def complete_sso_login(
+ self,
+ registered_user_id: str,
+ request: SynapseRequest,
+ client_redirect_url: str,
+ ):
+ """Having figured out a mxid for this user, complete the HTTP request
+
+ Args:
+ registered_user_id: The registered user ID to complete SSO login for.
+ request: The request to complete.
+ client_redirect_url: The URL to which to redirect the user at the end of the
+ process.
+ """
+ # Create a login token
+ login_token = self.macaroon_gen.generate_short_term_login_token(
+ registered_user_id
+ )
+
+ # Append the login token to the original redirect URL (i.e. with its query
+ # parameters kept intact) to build the URL to which the template needs to
+ # redirect the users once they have clicked on the confirmation link.
+ redirect_url = self.add_query_param_to_url(
+ client_redirect_url, "loginToken", login_token
+ )
+
+ # if the client is whitelisted, we can redirect straight to it
+ if client_redirect_url.startswith(self._whitelisted_sso_clients):
+ request.redirect(redirect_url)
+ finish_request(request)
+ return
+
+ # Otherwise, serve the redirect confirmation page.
+
+ # Remove the query parameters from the redirect URL to get a shorter version of
+ # it. This is only to display a human-readable URL in the template, but not the
+ # URL we redirect users to.
+ redirect_url_no_params = client_redirect_url.split("?")[0]
+
+ html = self._sso_redirect_confirm_template.render(
+ display_url=redirect_url_no_params,
+ redirect_url=redirect_url,
+ server_name=self._server_name,
+ ).encode("utf-8")
+
+ request.setResponseCode(200)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(b"Content-Length", b"%d" % (len(html),))
+ request.write(html)
+ finish_request(request)
+
+ @staticmethod
+ def add_query_param_to_url(url: str, param_name: str, param: Any):
+ url_parts = list(urllib.parse.urlparse(url))
+ query = dict(urllib.parse.parse_qsl(url_parts[4]))
+ query.update({param_name: param})
+ url_parts[4] = urllib.parse.urlencode(query)
+ return urllib.parse.urlunparse(url_parts)
+
@attr.s
class MacaroonGenerator(object):
hs = attr.ib()
- def generate_access_token(self, user_id, extra_caveats=None):
+ def generate_access_token(
+ self, user_id: str, extra_caveats: Optional[List[str]] = None
+ ) -> str:
extra_caveats = extra_caveats or []
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
@@ -946,16 +1042,9 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat(caveat)
return macaroon.serialize()
- def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
- """
-
- Args:
- user_id (unicode):
- duration_in_ms (int):
-
- Returns:
- unicode
- """
+ def generate_short_term_login_token(
+ self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ ) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
@@ -963,12 +1052,12 @@ class MacaroonGenerator(object):
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
- def generate_delete_pusher_token(self, user_id):
+ def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
return macaroon.serialize()
- def _generate_base_macaroon(self, user_id):
+ def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
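
A minimal standalone sketch (not part of the patch; the whitelist value and token are made up) of the redirect handling that complete_sso_login and add_query_param_to_url implement above: the login token is appended to the client's redirect URL with its existing query parameters preserved, and whitelisted clients skip the confirmation page via str.startswith against a tuple of prefixes.

import urllib.parse

def add_query_param_to_url(url: str, param_name: str, param: str) -> str:
    # Parse the URL, merge the new parameter into its query string, and rebuild it.
    url_parts = list(urllib.parse.urlparse(url))
    query = dict(urllib.parse.parse_qsl(url_parts[4]))
    query[param_name] = param
    url_parts[4] = urllib.parse.urlencode(query)
    return urllib.parse.urlunparse(url_parts)

# Hypothetical whitelist; Synapse reads this from the sso_client_whitelist config.
whitelisted_sso_clients = ("https://client.example.com/",)

client_redirect_url = "https://client.example.com/?homeserver=example.com"
redirect_url = add_query_param_to_url(client_redirect_url, "loginToken", "abc123")
# redirect_url == "https://client.example.com/?homeserver=example.com&loginToken=abc123"

if client_redirect_url.startswith(whitelisted_sso_clients):
    pass  # whitelisted: redirect straight away; otherwise serve the confirmation page
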
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index a514c30714..993499f446 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -125,8 +125,14 @@ class DeviceWorkerHandler(BaseHandler):
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
+
+ tracked_users = set(users_who_share_room)
+
+ # Always tell the user about their own devices
+ tracked_users.add(user_id)
+
changed = yield self.store.get_users_whose_devices_changed(
- from_token.device_list_key, users_who_share_room
+ from_token.device_list_key, tracked_users
)
# Then work out if any users have since joined
@@ -456,7 +462,11 @@ class DeviceHandler(DeviceWorkerHandler):
room_ids = yield self.store.get_rooms_for_user(user_id)
- yield self.notifier.on_new_event("device_list_key", position, rooms=room_ids)
+ # specify the user ID too since the user should always get their own device list
+ # updates, even if they aren't in any rooms.
+ yield self.notifier.on_new_event(
+ "device_list_key", position, users=[user_id], rooms=room_ids
+ )
if hosts:
logger.info(
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 0b23ca919a..1d842c369b 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -13,11 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import collections
import logging
import string
-from typing import List
+from typing import Iterable, List, Optional
from twisted.internet import defer
@@ -30,6 +28,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
+from synapse.appservice import ApplicationService
from synapse.types import Requester, RoomAlias, UserID, get_domain_from_id
from ._base import BaseHandler
@@ -57,7 +56,13 @@ class DirectoryHandler(BaseHandler):
self.spam_checker = hs.get_spam_checker()
@defer.inlineCallbacks
- def _create_association(self, room_alias, room_id, servers=None, creator=None):
+ def _create_association(
+ self,
+ room_alias: RoomAlias,
+ room_id: str,
+ servers: Optional[Iterable[str]] = None,
+ creator: Optional[str] = None,
+ ):
# general association creation for both human users and app services
for wchar in string.whitespace:
@@ -83,17 +88,21 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks
def create_association(
- self, requester, room_alias, room_id, servers=None, check_membership=True,
+ self,
+ requester: Requester,
+ room_alias: RoomAlias,
+ room_id: str,
+ servers: Optional[List[str]] = None,
+ check_membership: bool = True,
):
"""Attempt to create a new alias
Args:
- requester (Requester)
- room_alias (RoomAlias)
- room_id (str)
- servers (list[str]|None): List of servers that others servers
- should try and join via
- check_membership (bool): Whether to check if the user is in the room
+ requester
+ room_alias
+ room_id
+ servers: Iterable of servers that other servers should try and join via
+ check_membership: Whether to check if the user is in the room
before the alias can be set (if the server's config requires it).
Returns:
@@ -147,15 +156,15 @@ class DirectoryHandler(BaseHandler):
yield self._create_association(room_alias, room_id, servers, creator=user_id)
@defer.inlineCallbacks
- def delete_association(self, requester, room_alias):
+ def delete_association(self, requester: Requester, room_alias: RoomAlias):
"""Remove an alias from the directory
(this is only meant for human users; AS users should call
delete_appservice_association)
Args:
- requester (Requester):
- room_alias (RoomAlias):
+ requester
+ room_alias
Returns:
Deferred[unicode]: room id that the alias used to point to
@@ -191,16 +200,16 @@ class DirectoryHandler(BaseHandler):
room_id = yield self._delete_association(room_alias)
try:
- yield self._update_canonical_alias(
- requester, requester.user.to_string(), room_id, room_alias
- )
+ yield self._update_canonical_alias(requester, user_id, room_id, room_alias)
except AuthError as e:
logger.info("Failed to update alias events: %s", e)
return room_id
@defer.inlineCallbacks
- def delete_appservice_association(self, service, room_alias):
+ def delete_appservice_association(
+ self, service: ApplicationService, room_alias: RoomAlias
+ ):
if not service.is_interested_in_alias(room_alias.to_string()):
raise SynapseError(
400,
@@ -210,7 +219,7 @@ class DirectoryHandler(BaseHandler):
yield self._delete_association(room_alias)
@defer.inlineCallbacks
- def _delete_association(self, room_alias):
+ def _delete_association(self, room_alias: RoomAlias):
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
@@ -219,7 +228,7 @@ class DirectoryHandler(BaseHandler):
return room_id
@defer.inlineCallbacks
- def get_association(self, room_alias):
+ def get_association(self, room_alias: RoomAlias):
room_id = None
if self.hs.is_mine(room_alias):
result = yield self.get_association_from_room_alias(room_alias)
@@ -284,7 +293,9 @@ class DirectoryHandler(BaseHandler):
)
@defer.inlineCallbacks
- def _update_canonical_alias(self, requester, user_id, room_id, room_alias):
+ def _update_canonical_alias(
+ self, requester: Requester, user_id: str, room_id: str, room_alias: RoomAlias
+ ):
"""
Send an updated canonical alias event if the removed alias was set as
the canonical alias or listed in the alt_aliases field.
@@ -307,15 +318,17 @@ class DirectoryHandler(BaseHandler):
send_update = True
content.pop("alias", "")
- # Filter alt_aliases for the removed alias.
- alt_aliases = content.pop("alt_aliases", None)
- # If the aliases are not a list (or not found) do not attempt to modify
- # the list.
- if isinstance(alt_aliases, collections.Sequence):
+ # Filter the alt_aliases property for the removed alias. Note that the
+ # value is not modified if alt_aliases is of an unexpected form.
+ alt_aliases = content.get("alt_aliases")
+ if isinstance(alt_aliases, (list, tuple)) and alias_str in alt_aliases:
send_update = True
alt_aliases = [alias for alias in alt_aliases if alias != alias_str]
+
if alt_aliases:
content["alt_aliases"] = alt_aliases
+ else:
+ del content["alt_aliases"]
if send_update:
yield self.event_creation_handler.create_and_send_nonmember_event(
@@ -331,7 +344,7 @@ class DirectoryHandler(BaseHandler):
)
@defer.inlineCallbacks
- def get_association_from_room_alias(self, room_alias):
+ def get_association_from_room_alias(self, room_alias: RoomAlias):
result = yield self.store.get_association_from_room_alias(room_alias)
if not result:
# Query AS to see if it exists
@@ -339,7 +352,7 @@ class DirectoryHandler(BaseHandler):
result = yield as_handler.query_room_alias_exists(room_alias)
return result
- def can_modify_alias(self, alias, user_id=None):
+ def can_modify_alias(self, alias: RoomAlias, user_id: Optional[str] = None):
# Any application service "interested" in an alias they are regexing on
# can modify the alias.
# Users can only modify the alias if ALL the interested services have
@@ -360,22 +373,42 @@ class DirectoryHandler(BaseHandler):
return defer.succeed(True)
@defer.inlineCallbacks
- def _user_can_delete_alias(self, alias, user_id):
+ def _user_can_delete_alias(self, alias: RoomAlias, user_id: str):
+ """Determine whether a user can delete an alias.
+
+ One of the following must be true:
+
+ 1. The user created the alias.
+ 2. The user is a server administrator.
+ 3. The user has a power-level sufficient to send a canonical alias event
+ for the current room.
+
+ """
creator = yield self.store.get_room_alias_creator(alias.to_string())
if creator is not None and creator == user_id:
return True
- is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
- return is_admin
+ # Resolve the alias to the corresponding room.
+ room_mapping = yield self.get_association(alias)
+ room_id = room_mapping["room_id"]
+ if not room_id:
+ return False
+
+ res = yield self.auth.check_can_change_room_list(
+ room_id, UserID.from_string(user_id)
+ )
+ return res
@defer.inlineCallbacks
- def edit_published_room_list(self, requester, room_id, visibility):
+ def edit_published_room_list(
+ self, requester: Requester, room_id: str, visibility: str
+ ):
"""Edit the entry of the room in the published room list.
requester
- room_id (str)
- visibility (str): "public" or "private"
+ room_id
+ visibility: "public" or "private"
"""
user_id = requester.user.to_string()
@@ -400,7 +433,15 @@ class DirectoryHandler(BaseHandler):
if room is None:
raise SynapseError(400, "Unknown room")
- yield self.auth.check_can_change_room_list(room_id, requester.user)
+ can_change_room_list = yield self.auth.check_can_change_room_list(
+ room_id, requester.user
+ )
+ if not can_change_room_list:
+ raise AuthError(
+ 403,
+ "This server requires you to be a moderator in the room to"
+ " edit its room list entry",
+ )
making_public = visibility == "public"
if making_public:
@@ -421,16 +462,16 @@ class DirectoryHandler(BaseHandler):
@defer.inlineCallbacks
def edit_published_appservice_room_list(
- self, appservice_id, network_id, room_id, visibility
+ self, appservice_id: str, network_id: str, room_id: str, visibility: str
):
"""Add or remove a room from the appservice/network specific public
room list.
Args:
- appservice_id (str): ID of the appservice that owns the list
- network_id (str): The ID of the network the list is associated with
- room_id (str)
- visibility (str): either "public" or "private"
+ appservice_id: ID of the appservice that owns the list
+ network_id: The ID of the network the list is associated with
+ room_id
+ visibility: either "public" or "private"
"""
if visibility not in ["public", "private"]:
raise SynapseError(400, "Invalid visibility setting")
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 95a9d71f41..8f1bc0323c 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -54,19 +54,23 @@ class E2eKeysHandler(object):
self._edu_updater = SigningKeyEduUpdater(hs, self)
+ federation_registry = hs.get_federation_registry()
+
self._is_master = hs.config.worker_app is None
if not self._is_master:
self._user_device_resync_client = ReplicationUserDevicesResyncRestServlet.make_client(
hs
)
+ else:
+ # Only register this edu handler on master as it requires writing
+ # device updates to the db
+ #
+ # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
+ federation_registry.register_edu_handler(
+ "org.matrix.signing_key_update",
+ self._edu_updater.incoming_signing_key_update,
+ )
- federation_registry = hs.get_federation_registry()
-
- # FIXME: switch to m.signing_key_update when MSC1756 is merged into the spec
- federation_registry.register_edu_handler(
- "org.matrix.signing_key_update",
- self._edu_updater.incoming_signing_key_update,
- )
# doesn't really work as part of the generic query API, because the
# query request requires an object POST, but we abuse the
# "query handler" interface.
@@ -170,8 +174,8 @@ class E2eKeysHandler(object):
"""This is called when we are querying the device list of a user on
a remote homeserver and their device list is not in the device list
cache. If we share a room with this user and we're not querying for
- specific user we will update the cache
- with their device list."""
+ specific user we will update the cache with their device list.
+ """
destination_query = remote_queries_not_in_cache[destination]
@@ -957,13 +961,19 @@ class E2eKeysHandler(object):
return signature_list, failures
@defer.inlineCallbacks
- def _get_e2e_cross_signing_verify_key(self, user_id, key_type, from_user_id=None):
- """Fetch the cross-signing public key from storage and interpret it.
+ def _get_e2e_cross_signing_verify_key(
+ self, user_id: str, key_type: str, from_user_id: str = None
+ ):
+ """Fetch a cross-signing public key, either from local storage or by querying a remote homeserver.
+
+ First, attempt to fetch the cross-signing public key from storage.
+ If that fails, query the keys from the homeserver they belong to
+ and update our local copy.
Args:
- user_id (str): the user whose key should be fetched
- key_type (str): the type of key to fetch
- from_user_id (str): the user that we are fetching the keys for.
+ user_id: the user whose key should be fetched
+ key_type: the type of key to fetch
+ from_user_id: the user that we are fetching the keys for.
This affects what signatures are fetched.
Returns:
@@ -972,16 +982,140 @@ class E2eKeysHandler(object):
Raises:
NotFoundError: if the key is not found
+ SynapseError: if `user_id` is invalid
"""
+ user = UserID.from_string(user_id)
key = yield self.store.get_e2e_cross_signing_key(
user_id, key_type, from_user_id
)
+
+ if key:
+ # We found a copy of this key in our database. Decode and return it
+ key_id, verify_key = get_verify_key_from_cross_signing_key(key)
+ return key, key_id, verify_key
+
+ # If we couldn't find the key locally, and we're looking for keys of
+ # another user then attempt to fetch the missing key from the remote
+ # user's server.
+ #
+ # We may run into this in possible edge cases where a user tries to
+ # cross-sign a remote user, but does not share any rooms with them yet.
+ # Thus, we would not have their key list yet. We instead fetch the key,
+ # store it and notify clients of new, associated device IDs.
+ if self.is_mine(user) or key_type not in ["master", "self_signing"]:
+ # Note that master and self_signing keys are the only cross-signing keys we
+ # can request over federation
+ raise NotFoundError("No %s key found for %s" % (key_type, user_id))
+
+ (
+ key,
+ key_id,
+ verify_key,
+ ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
+
if key is None:
- logger.debug("no %s key found for %s", key_type, user_id)
raise NotFoundError("No %s key found for %s" % (key_type, user_id))
- key_id, verify_key = get_verify_key_from_cross_signing_key(key)
+
return key, key_id, verify_key
+ @defer.inlineCallbacks
+ def _retrieve_cross_signing_keys_for_remote_user(
+ self, user: UserID, desired_key_type: str,
+ ):
+ """Queries cross-signing keys for a remote user and saves them to the database
+
+ Only the key specified by `desired_key_type` will be returned, while all retrieved
+ keys will be saved regardless.
+
+ Args:
+ user: The user to query remote keys for
+ desired_key_type: The type of key to receive. One of "master", "self_signing"
+
+ Returns:
+ Deferred[Tuple[Optional[Dict], Optional[str], Optional[VerifyKey]]]: A tuple
+ of the retrieved key content, the key's ID and the matching VerifyKey.
+ If the key cannot be retrieved, all values in the tuple will instead be None.
+ """
+ try:
+ remote_result = yield self.federation.query_user_devices(
+ user.domain, user.to_string()
+ )
+ except Exception as e:
+ logger.warning(
+ "Unable to query %s for cross-signing keys of user %s: %s %s",
+ user.domain,
+ user.to_string(),
+ type(e),
+ e,
+ )
+ return None, None, None
+
+ # Process each of the retrieved cross-signing keys
+ desired_key = None
+ desired_key_id = None
+ desired_verify_key = None
+ retrieved_device_ids = []
+ for key_type in ["master", "self_signing"]:
+ key_content = remote_result.get(key_type + "_key")
+ if not key_content:
+ continue
+
+ # Ensure these keys belong to the correct user
+ if "user_id" not in key_content:
+ logger.warning(
+ "Invalid %s key retrieved, missing user_id field: %s",
+ key_type,
+ key_content,
+ )
+ continue
+ if user.to_string() != key_content["user_id"]:
+ logger.warning(
+ "Found %s key of user %s when querying for keys of user %s",
+ key_type,
+ key_content["user_id"],
+ user.to_string(),
+ )
+ continue
+
+ # Validate the key contents
+ try:
+ # verify_key is a VerifyKey from signedjson, which uses
+ # .version to denote the portion of the key ID after the
+ # algorithm and colon, which is the device ID
+ key_id, verify_key = get_verify_key_from_cross_signing_key(key_content)
+ except ValueError as e:
+ logger.warning(
+ "Invalid %s key retrieved: %s - %s %s",
+ key_type,
+ key_content,
+ type(e),
+ e,
+ )
+ continue
+
+ # Note down the device ID attached to this key
+ retrieved_device_ids.append(verify_key.version)
+
+ # If this is the desired key type, save it and its ID/VerifyKey
+ if key_type == desired_key_type:
+ desired_key = key_content
+ desired_verify_key = verify_key
+ desired_key_id = key_id
+
+ # At the same time, store this key in the db for subsequent queries
+ yield self.store.set_e2e_cross_signing_key(
+ user.to_string(), key_type, key_content
+ )
+
+ # Notify clients that new devices for this user have been discovered
+ if retrieved_device_ids:
+ # XXX is this necessary?
+ yield self.device_handler.notify_device_update(
+ user.to_string(), retrieved_device_ids
+ )
+
+ return desired_key, desired_key_id, desired_verify_key
+
def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
"""Check a cross-signing key uploaded by a user. Performs some basic sanity
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index f1b4424a02..9abaf13b8f 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -207,6 +207,13 @@ class E2eRoomKeysHandler(object):
changed = False # if anything has changed, we need to update the etag
for room_id, room in iteritems(room_keys["rooms"]):
for session_id, room_key in iteritems(room["sessions"]):
+ if not isinstance(room_key["is_verified"], bool):
+ msg = (
+ "is_verified must be a boolean in keys for session %s in "
+ "room %s" % (session_id, room_id)
+ )
+ raise SynapseError(400, msg, Codes.INVALID_PARAM)
+
log_kv(
{
"message": "Trying to upload room key",
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index a0103addd3..b743fc2dcc 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -160,7 +160,7 @@ class MessageHandler(object):
raise NotFoundError("Can't find event for token %s" % (at_token,))
visible_events = yield filter_events_for_client(
- self.storage, user_id, last_events, apply_retention_policies=False
+ self.storage, user_id, last_events, filter_send_to_client=False
)
event = last_events[0]
@@ -888,19 +888,60 @@ class EventCreationHandler(object):
yield self.base_handler.maybe_kick_guest_users(event, context)
if event.type == EventTypes.CanonicalAlias:
- # Check the alias is acually valid (at this time at least)
+ # Validate a newly added alias or newly added alt_aliases.
+
+ original_alias = None
+ original_alt_aliases = set()
+
+ original_event_id = event.unsigned.get("replaces_state")
+ if original_event_id:
+ original_event = yield self.store.get_event(original_event_id)
+
+ if original_event:
+ original_alias = original_event.content.get("alias", None)
+ original_alt_aliases = original_event.content.get("alt_aliases", [])
+
+ # Check the alias is currently valid (if it has changed).
room_alias_str = event.content.get("alias", None)
- if room_alias_str:
+ directory_handler = self.hs.get_handlers().directory_handler
+ if room_alias_str and room_alias_str != original_alias:
room_alias = RoomAlias.from_string(room_alias_str)
- directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (room_alias_str,),
+ Codes.BAD_ALIAS,
)
+ # Check that alt_aliases is the proper form.
+ alt_aliases = event.content.get("alt_aliases", [])
+ if not isinstance(alt_aliases, (list, tuple)):
+ raise SynapseError(
+ 400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM
+ )
+
+ # If the old version of alt_aliases is of an unknown form,
+ # completely replace it.
+ if not isinstance(original_alt_aliases, (list, tuple)):
+ original_alt_aliases = []
+
+ # Check that each alias is currently valid.
+ new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
+ if new_alt_aliases:
+ for alias_str in new_alt_aliases:
+ room_alias = RoomAlias.from_string(alias_str)
+ mapping = yield directory_handler.get_association(room_alias)
+
+ if mapping["room_id"] != event.room_id:
+ raise SynapseError(
+ 400,
+ "Room alias %s does not point to the room"
+ % (alias_str,),
+ Codes.BAD_ALIAS,
+ )
+
federation_handler = self.hs.get_handlers().federation_handler
if event.type == EventTypes.Member:
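
The alt_aliases validation above only resolves aliases that are new relative to the state event being replaced. A standalone sketch (not part of the patch; get_association below is a stand-in for the directory handler):

def get_association(alias: str) -> dict:
    # Stand-in for DirectoryHandler.get_association; maps every alias to one room.
    return {"room_id": "!room:example.com"}

event_room_id = "!room:example.com"
original_alt_aliases = ["#a:example.com"]           # from the replaced state event
alt_aliases = ["#a:example.com", "#b:example.com"]  # from the new event content

if not isinstance(alt_aliases, (list, tuple)):
    raise ValueError("The alt_aliases property must be a list.")

# Only aliases that were not already present are checked against the directory;
# "#a:example.com" is skipped, "#b:example.com" is resolved.
new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
for alias_str in new_alt_aliases:
    mapping = get_association(alias_str)
    if mapping["room_id"] != event_room_id:
        raise ValueError("Room alias %s does not point to the room" % (alias_str,))
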
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 8ee870f0bb..f580ab2e9f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -292,16 +292,6 @@ class RoomCreationHandler(BaseHandler):
except AuthError as e:
logger.warning("Unable to update PLs in old room: %s", e)
- new_pl_content = copy_power_levels_contents(old_room_pl_state.content)
-
- # pre-msc2260 rooms may not have the right setting for aliases. If no other
- # value is set, set it now.
- events_default = new_pl_content.get("events_default", 0)
- new_pl_content.setdefault("events", {}).setdefault(
- EventTypes.Aliases, events_default
- )
-
- logger.debug("Setting correct PLs in new room to %s", new_pl_content)
yield self.event_creation_handler.create_and_send_nonmember_event(
requester,
{
@@ -309,7 +299,7 @@ class RoomCreationHandler(BaseHandler):
"state_key": "",
"room_id": new_room_id,
"sender": requester.user.to_string(),
- "content": new_pl_content,
+ "content": old_room_pl_state.content,
},
ratelimit=False,
)
@@ -814,10 +804,6 @@ class RoomCreationHandler(BaseHandler):
EventTypes.RoomHistoryVisibility: 100,
EventTypes.CanonicalAlias: 50,
EventTypes.RoomAvatar: 50,
- # MSC2260: Allow everybody to send alias events by default
- # This will be reudundant on pre-MSC2260 rooms, since the
- # aliases event is special-cased.
- EventTypes.Aliases: 0,
EventTypes.Tombstone: 100,
EventTypes.ServerACL: 100,
},
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 7f411b53b9..72c109981b 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -23,9 +23,9 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
+from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.module_api import ModuleApi
-from synapse.rest.client.v1.login import SSOAuthHandler
from synapse.types import (
UserID,
map_username_to_mxid_localpart,
@@ -48,7 +48,7 @@ class Saml2SessionData:
class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
- self._sso_auth_handler = SSOAuthHandler(hs)
+ self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
self._clock = hs.get_clock()
@@ -74,6 +74,8 @@ class SamlHandler:
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
+ self._error_html_content = hs.config.saml2_error_html_content
+
def handle_redirect_request(self, client_redirect_url):
"""Handle an incoming request to /login/sso/redirect
@@ -115,8 +117,23 @@ class SamlHandler:
# the dict.
self.expire_sessions()
- user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
- self._sso_auth_handler.complete_sso_login(user_id, request, relay_state)
+ try:
+ user_id = await self._map_saml_response_to_user(resp_bytes, relay_state)
+ except Exception as e:
+ # If decoding the response or mapping it to a user failed, then log the
+ # error and tell the user that something went wrong.
+ logger.error(e)
+
+ request.setResponseCode(400)
+ request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+ request.setHeader(
+ b"Content-Length", b"%d" % (len(self._error_html_content),)
+ )
+ request.write(self._error_html_content.encode("utf8"))
+ finish_request(request)
+ return
+
+ self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(self, resp_bytes, client_redirect_url):
try:
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index d90c9e0108..12657ca698 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -13,10 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import Optional
from twisted.internet import defer
from synapse.api.errors import Codes, StoreError, SynapseError
+from synapse.types import Requester
from ._base import BaseHandler
@@ -32,14 +34,17 @@ class SetPasswordHandler(BaseHandler):
self._device_handler = hs.get_device_handler()
@defer.inlineCallbacks
- def set_password(self, user_id, newpassword, requester=None):
+ def set_password(
+ self,
+ user_id: str,
+ new_password: str,
+ logout_devices: bool,
+ requester: Optional[Requester] = None,
+ ):
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
- password_hash = yield self._auth_handler.hash(newpassword)
-
- except_device_id = requester.device_id if requester else None
- except_access_token_id = requester.access_token_id if requester else None
+ password_hash = yield self._auth_handler.hash(new_password)
try:
yield self.store.user_set_password_hash(user_id, password_hash)
@@ -48,14 +53,18 @@ class SetPasswordHandler(BaseHandler):
raise SynapseError(404, "Unknown user", Codes.NOT_FOUND)
raise e
- # we want to log out all of the user's other sessions. First delete
- # all his other devices.
- yield self._device_handler.delete_all_devices_for_user(
- user_id, except_device_id=except_device_id
- )
-
- # and now delete any access tokens which weren't associated with
- # devices (or were associated with this device).
- yield self._auth_handler.delete_access_tokens_for_user(
- user_id, except_token_id=except_access_token_id
- )
+ # Optionally, log out all of the user's other sessions.
+ if logout_devices:
+ except_device_id = requester.device_id if requester else None
+ except_access_token_id = requester.access_token_id if requester else None
+
+ # First delete all of their other devices.
+ yield self._device_handler.delete_all_devices_for_user(
+ user_id, except_device_id=except_device_id
+ )
+
+ # and now delete any access tokens which weren't associated with
+ # devices (or were associated with this device).
+ yield self._auth_handler.delete_access_tokens_for_user(
+ user_id, except_token_id=except_access_token_id
+ )
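
A minimal sketch (not part of the patch; Requester here is a stand-in for synapse.types.Requester) of the logout_devices behaviour added above: when other sessions are logged out, the requester's own device and access token are excluded so the session performing the password change survives, while an admin-initiated reset with no requester logs everything out.

from typing import NamedTuple, Optional, Tuple

class Requester(NamedTuple):
    device_id: Optional[str]
    access_token_id: Optional[int]

def exclusions(
    logout_devices: bool, requester: Optional[Requester]
) -> Optional[Tuple[Optional[str], Optional[int]]]:
    # Returns the (device_id, access_token_id) to spare, or None if nothing is deleted.
    if not logout_devices:
        return None
    except_device_id = requester.device_id if requester else None
    except_access_token_id = requester.access_token_id if requester else None
    return except_device_id, except_access_token_id

print(exclusions(True, Requester("CURRENTDEVICE", 42)))  # ('CURRENTDEVICE', 42)
print(exclusions(True, None))   # (None, None): nothing is spared
print(exclusions(False, None))  # None: all sessions stay logged in
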
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 669dbc8a48..cfd5dfc9e5 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1143,9 +1143,14 @@ class SyncHandler(object):
user_id
)
+ tracked_users = set(users_who_share_room)
+
+ # Always tell the user about their own devices
+ tracked_users.add(user_id)
+
# Step 1a, check for changes in devices of users we share a room with
users_that_have_changed = await self.store.get_users_whose_devices_changed(
- since_token.device_list_key, users_who_share_room
+ since_token.device_list_key, tracked_users
)
# Step 1b, check for newly joined rooms
|