diff --git a/.buildkite/scripts/test_old_deps.sh b/.buildkite/scripts/test_old_deps.sh
index cdb77b556c..9905c4bc4f 100755
--- a/.buildkite/scripts/test_old_deps.sh
+++ b/.buildkite/scripts/test_old_deps.sh
@@ -6,7 +6,7 @@
set -ex
apt-get update
-apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev zlib1g-dev tox
+apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
export LANG="C.UTF-8"
diff --git a/changelog.d/8565.misc b/changelog.d/8565.misc
new file mode 100644
index 0000000000..7bef422618
--- /dev/null
+++ b/changelog.d/8565.misc
@@ -0,0 +1 @@
+Simplify the way the `HomeServer` object caches its internal attributes.
diff --git a/changelog.d/8799.bugfix b/changelog.d/8799.bugfix
new file mode 100644
index 0000000000..a7e6b3556d
--- /dev/null
+++ b/changelog.d/8799.bugfix
@@ -0,0 +1 @@
+Allow per-room profiles to be used for the server notice user.
diff --git a/changelog.d/8800.misc b/changelog.d/8800.misc
new file mode 100644
index 0000000000..57cca8fee5
--- /dev/null
+++ b/changelog.d/8800.misc
@@ -0,0 +1 @@
+Add additional error checking for OpenID Connect and SAML mapping providers.
diff --git a/changelog.d/8804.feature b/changelog.d/8804.feature
new file mode 100644
index 0000000000..a907c8106c
--- /dev/null
+++ b/changelog.d/8804.feature
@@ -0,0 +1 @@
+Allow the `Date` header through CORS. Contributed by Nicolas Chamo.
diff --git a/changelog.d/8809.misc b/changelog.d/8809.misc
new file mode 100644
index 0000000000..bbf83cf18d
--- /dev/null
+++ b/changelog.d/8809.misc
@@ -0,0 +1 @@
+Remove unnecessary function arguments and add typing to several membership replication classes.
\ No newline at end of file
diff --git a/changelog.d/8819.misc b/changelog.d/8819.misc
new file mode 100644
index 0000000000..a5793273a5
--- /dev/null
+++ b/changelog.d/8819.misc
@@ -0,0 +1 @@
+Add tests for `password_auth_provider`s.
diff --git a/changelog.d/8820.feature b/changelog.d/8820.feature
new file mode 100644
index 0000000000..9e35861b11
--- /dev/null
+++ b/changelog.d/8820.feature
@@ -0,0 +1 @@
+Add a config option, `push.group_unread_count_by_room`, which controls whether unread message counts in push notifications are defined as "the number of rooms with unread messages" or "total unread messages".
diff --git a/changelog.d/8833.removal b/changelog.d/8833.removal
new file mode 100644
index 0000000000..5c2d195f94
--- /dev/null
+++ b/changelog.d/8833.removal
@@ -0,0 +1 @@
+Disable pretty printing JSON responses for curl. Users who want pretty-printed output should use [jq](https://stedolan.github.io/jq/) in combination with curl. Contributed by @tulir.
diff --git a/changelog.d/8835.bugfix b/changelog.d/8835.bugfix
new file mode 100644
index 0000000000..446d04aa55
--- /dev/null
+++ b/changelog.d/8835.bugfix
@@ -0,0 +1 @@
+Fix minor long-standing bug in login, where we would offer the `password` login type if a custom auth provider supported it, even if password login was disabled.
diff --git a/changelog.d/8843.feature b/changelog.d/8843.feature
new file mode 100644
index 0000000000..824d46d5aa
--- /dev/null
+++ b/changelog.d/8843.feature
@@ -0,0 +1 @@
+Add `force_purge` option to the delete-room admin API.
diff --git a/changelog.d/8845.misc b/changelog.d/8845.misc
new file mode 100644
index 0000000000..7db1c31520
--- /dev/null
+++ b/changelog.d/8845.misc
@@ -0,0 +1 @@
+Drop redundant database index on `event_json`.
diff --git a/changelog.d/8847.misc b/changelog.d/8847.misc
new file mode 100644
index 0000000000..5028997b04
--- /dev/null
+++ b/changelog.d/8847.misc
@@ -0,0 +1 @@
+Simplify `uk.half-shot.msc2778.login.application_service` login handler.
diff --git a/changelog.d/8848.bugfix b/changelog.d/8848.bugfix
new file mode 100644
index 0000000000..499e66f05b
--- /dev/null
+++ b/changelog.d/8848.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug which caused Synapse to require unspecified parameters during user-interactive authentication.
diff --git a/changelog.d/8849.misc b/changelog.d/8849.misc
new file mode 100644
index 0000000000..3dd496ce61
--- /dev/null
+++ b/changelog.d/8849.misc
@@ -0,0 +1 @@
+Refactor `password_auth_provider` support code.
diff --git a/changelog.d/8850.misc b/changelog.d/8850.misc
new file mode 100644
index 0000000000..4b54b8dd87
--- /dev/null
+++ b/changelog.d/8850.misc
@@ -0,0 +1 @@
+Add missing `ordering` to background database updates.
diff --git a/changelog.d/8851.misc b/changelog.d/8851.misc
new file mode 100644
index 0000000000..7bef422618
--- /dev/null
+++ b/changelog.d/8851.misc
@@ -0,0 +1 @@
+Simplify the way the `HomeServer` object caches its internal attributes.
diff --git a/changelog.d/8854.misc b/changelog.d/8854.misc
new file mode 100644
index 0000000000..5895df2d5c
--- /dev/null
+++ b/changelog.d/8854.misc
@@ -0,0 +1 @@
+Allow for specifying a room version when creating a room in unit tests via `RestHelper.create_room_as`.
\ No newline at end of file
diff --git a/changelog.d/8855.feature b/changelog.d/8855.feature
new file mode 100644
index 0000000000..77f7fe4e5d
--- /dev/null
+++ b/changelog.d/8855.feature
@@ -0,0 +1 @@
+Add support for re-trying generation of a localpart for OpenID Connect mapping providers.
diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md
index 0c05b0ed55..004a802e17 100644
--- a/docs/admin_api/rooms.md
+++ b/docs/admin_api/rooms.md
@@ -382,7 +382,7 @@ the new room. Users on other servers will be unaffected.
The API is:
-```json
+```
POST /_synapse/admin/v1/rooms/<room_id>/delete
```
@@ -439,6 +439,10 @@ The following JSON body parameters are available:
future attempts to join the room. Defaults to `false`.
* `purge` - Optional. If set to `true`, it will remove all traces of the room from your database.
Defaults to `true`.
+* `force_purge` - Optional, and ignored unless `purge` is `true`. If set to `true`, it
+ will force a purge to go ahead even if there are local users still in the room. Do not
+ use this unless a regular `purge` operation fails, as it could leave those users'
+ clients in a confused state.
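+
+For example, to shut down a room and force the purge even though local users
+remain in it, the body might be:
+
+```json
+{
+    "purge": true,
+    "force_purge": true
+}
+```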
The JSON body must not be empty. The body must be at least `{}`.
diff --git a/docs/password_auth_providers.md b/docs/password_auth_providers.md
index 7d98d9f255..d2cdb9b2f4 100644
--- a/docs/password_auth_providers.md
+++ b/docs/password_auth_providers.md
@@ -26,6 +26,7 @@ Password auth provider classes must provide the following methods:
It should perform any appropriate sanity checks on the provided
configuration, and return an object which is then passed into
+ `__init__`.
This method should have the `@staticmethod` decoration.
diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml
index df0f3e1d8e..394eb9a3ff 100644
--- a/docs/sample_config.yaml
+++ b/docs/sample_config.yaml
@@ -2271,6 +2271,16 @@ push:
#
#include_content: false
+ # When a push notification is received, an unread count is also sent.
+ # This number can either be calculated as the number of unread messages
+ # for the user, or the number of *rooms* the user has unread messages in.
+ #
+ # The default value is "true", meaning push clients will see the number of
+ # rooms with unread messages in them. Uncomment to instead send the number
+ # of unread messages.
+ #
+ #group_unread_count_by_room: false
+
# Spam checkers are third-party modules that can block specific actions
# of local users, such as creating rooms and registering undesirable
diff --git a/mypy.ini b/mypy.ini
index a5503abe26..3c8d303064 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -80,6 +80,7 @@ files =
synapse/util/metrics.py,
tests/replication,
tests/test_utils,
+ tests/handlers/test_password_providers.py,
tests/rest/client/v2_alpha/test_auth.py,
tests/util/test_stream_change_cache.py
diff --git a/synapse/config/push.py b/synapse/config/push.py
index a71baac89c..3adbfb73e6 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -23,6 +23,9 @@ class PushConfig(Config):
def read_config(self, config, **kwargs):
push_config = config.get("push") or {}
self.push_include_content = push_config.get("include_content", True)
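+ # whether push badge counts should count rooms with unread messages,
+ # rather than the total number of unread messages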
+ self.push_group_unread_count_by_room = push_config.get(
+ "group_unread_count_by_room", True
+ )
pusher_instances = config.get("pusher_instances") or []
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
@@ -68,4 +71,14 @@ class PushConfig(Config):
# include the event ID and room ID in push notification payloads.
#
#include_content: false
+
+ # When a push notification is received, an unread count is also sent.
+ # This number can either be calculated as the number of unread messages
+ # for the user, or the number of *rooms* the user has unread messages in.
+ #
+ # The default value is "true", meaning push clients will see the number of
+ # rooms with unread messages in them. Uncomment to instead send the number
+ # of unread messages.
+ #
+ #group_unread_count_by_room: false
"""
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 5163afd86c..c7dc07008a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,6 +26,7 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
Tuple,
Union,
@@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
# better way to break the loop
account_handler = ModuleApi(hs, self)
- self.password_providers = []
- for module, config in hs.config.password_providers:
- try:
- self.password_providers.append(
- module(config=config, account_handler=account_handler)
- )
- except Exception as e:
- logger.error("Error while initializing %r: %s", module, e)
- raise
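+ # Wrap each configured provider module in a PasswordProvider, which
+ # smooths over the differences between generations of the provider API.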
+ self.password_providers = [
+ PasswordProvider.load(module, config, account_handler)
+ for module, config in hs.config.password_providers
+ ]
- logger.info("Extra password_providers: %r", self.password_providers)
+ logger.info("Extra password_providers: %s", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
@@ -205,15 +202,23 @@ class AuthHandler(BaseHandler):
# type in the list. (NB that the spec doesn't require us to do so and
# clients which favour types that they don't understand over those that
# they do are technically broken)
+
+ # start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = []
- if self._password_enabled:
+ if hs.config.password_localdb_enabled:
login_types.append(LoginType.PASSWORD)
+
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
for t in provider.get_supported_login_types().keys():
if t not in login_types:
login_types.append(t)
+
+ if not self._password_enabled:
+ login_types.remove(LoginType.PASSWORD)
+
self._supported_login_types = login_types
+
# Login types and UI Auth types have a heavy overlap, but are not
# necessarily identical. Login types have SSO (and other login types)
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServlet.on_GET.
@@ -230,6 +235,13 @@ class AuthHandler(BaseHandler):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
+ # Ratelimiter for failed /login attempts
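+ # (this replaces an identical ratelimiter formerly constructed in LoginRestServlet)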
+ self._failed_login_attempts_ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ )
+
self._clock = self.hs.get_clock()
# Expire old UI auth sessions after a period of time.
@@ -642,14 +654,8 @@ class AuthHandler(BaseHandler):
res = await checker.check_auth(authdict, clientip=clientip)
return res
- # build a v1-login-style dict out of the authdict and fall back to the
- # v1 code
- user_id = authdict.get("user")
-
- if user_id is None:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
-
- (canonical_id, callback) = await self.validate_login(user_id, authdict)
+ # fall back to the v1 login flow
+ canonical_id, _ = await self.validate_login(authdict)
return canonical_id
def _get_params_recaptcha(self) -> dict:
@@ -824,17 +830,17 @@ class AuthHandler(BaseHandler):
return self._supported_login_types
async def validate_login(
- self, username: str, login_submission: Dict[str, Any]
+ self, login_submission: Dict[str, Any], ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
"""Authenticates the user for the /login API
- Also used by the user-interactive auth flow to validate
- m.login.password auth types.
+ Also used by the user-interactive auth flow to validate auth types which don't
+ have an explicit UIA handler, including m.login.password.
Args:
- username: username supplied by the user
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
+ ratelimit: whether to apply the failed_login_attempt ratelimiter
Returns:
A tuple of the canonical user id, and optional callback
to be called once the access token and device id are issued
@@ -843,38 +849,160 @@ class AuthHandler(BaseHandler):
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
-
- if username.startswith("@"):
- qualified_user_id = username
- else:
- qualified_user_id = UserID(username, self.hs.hostname).to_string()
-
login_type = login_submission.get("type")
- known_login_type = False
+ if not isinstance(login_type, str):
+ raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
+
+ # ideally, we wouldn't be checking the identifier unless we know we have a login
+ # method which uses it (https://github.com/matrix-org/synapse/issues/8836)
+ #
+ # But the auth providers' check_auth interface requires a username, so in
+ # practice we can only support login methods which we can map to a username
+ # anyway.
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
-
if login_type == LoginType.PASSWORD:
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
- if not password:
- raise SynapseError(400, "Missing parameter: password")
+ if not isinstance(password, str):
+ raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM)
- for provider in self.password_providers:
- if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
- known_login_type = True
- is_valid = await provider.check_password(qualified_user_id, password)
- if is_valid:
- return qualified_user_id, None
+ # map old-school login fields into new-school "identifier" fields.
+ identifier_dict = convert_client_dict_legacy_fields_to_identifier(
+ login_submission
+ )
- if not hasattr(provider, "get_supported_login_types") or not hasattr(
- provider, "check_auth"
- ):
- # this password provider doesn't understand custom login types
- continue
+ # convert phone type identifiers to generic threepids
+ if identifier_dict["type"] == "m.id.phone":
+ identifier_dict = login_id_phone_to_thirdparty(identifier_dict)
+
+ # convert threepid identifiers to user IDs
+ if identifier_dict["type"] == "m.id.thirdparty":
+ address = identifier_dict.get("address")
+ medium = identifier_dict.get("medium")
+
+ if medium is None or address is None:
+ raise SynapseError(400, "Invalid thirdparty identifier")
+
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See add_threepid in synapse/handlers/auth.py)
+ if medium == "email":
+ try:
+ address = canonicalise_email(address)
+ except ValueError as e:
+ raise SynapseError(400, str(e))
+
+ # We also apply account rate limiting using the 3PID as a key, as
+ # otherwise using 3PID bypasses the ratelimiting based on user ID.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.ratelimit(
+ (medium, address), update=False
+ )
+ # Check for login providers that support 3pid login types
+ if login_type == LoginType.PASSWORD:
+ # we've already checked that there is a (valid) password field
+ assert isinstance(password, str)
+ (
+ canonical_user_id,
+ callback_3pid,
+ ) = await self.check_password_provider_3pid(medium, address, password)
+ if canonical_user_id:
+ # Authentication through password provider and 3pid succeeded
+ return canonical_user_id, callback_3pid
+
+ # No password providers were able to handle this 3pid
+ # Check local store
+ user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+ medium, address
+ )
+ if not user_id:
+ logger.warning(
+ "unknown 3pid identifier medium %s, address %r", medium, address
+ )
+ # We mark that we've failed to log in here, as
+ # `check_password_provider_3pid` might have returned `None` due
+ # to an incorrect password, rather than the account not
+ # existing.
+ #
+ # If it returned None but the 3PID was bound then we won't hit
+ # this code path, which is fine as then the per-user ratelimit
+ # will kick in below.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.can_do_action(
+ (medium, address)
+ )
+ raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+ identifier_dict = {"type": "m.id.user", "user": user_id}
+
+ # by this point, the identifier should be an m.id.user: if it's anything
+ # else, we haven't understood it.
+ if identifier_dict["type"] != "m.id.user":
+ raise SynapseError(400, "Unknown login identifier type")
+
+ username = identifier_dict.get("user")
+ if not username:
+ raise SynapseError(400, "User identifier is missing 'user' key")
+
+ if username.startswith("@"):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ # Check if we've hit the failed ratelimit (but don't update it)
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.ratelimit(
+ qualified_user_id.lower(), update=False
+ )
+
+ try:
+ return await self._validate_userid_login(username, login_submission)
+ except LoginError:
+ # The user has failed to log in, so we need to update the rate
+ # limiter. Using `can_do_action` avoids us raising a ratelimit
+ # exception and masking the LoginError. The actual ratelimiting
+ # should have happened above.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.can_do_action(
+ qualified_user_id.lower()
+ )
+ raise
+
+ async def _validate_userid_login(
+ self, username: str, login_submission: Dict[str, Any],
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ """Helper for validate_login
+
+ Handles login, once we've mapped 3pids onto userids
+
+ Args:
+ username: the username, from the identifier dict
+ login_submission: the whole of the login submission
+ (including 'type' and other relevant fields)
+ Returns:
+ A tuple of the canonical user id, and optional callback
+ to be called once the access token and device id are issued
+ Raises:
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
+ """
+ if username.startswith("@"):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ login_type = login_submission.get("type")
+ # we already checked that we have a valid login type
+ assert isinstance(login_type, str)
+
+ known_login_type = False
+
+ for provider in self.password_providers:
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
@@ -899,15 +1027,17 @@ class AuthHandler(BaseHandler):
result = await provider.check_auth(username, login_type, login_dict)
if result:
- if isinstance(result, str):
- result = (result, None)
return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True
+ # we've already checked that there is a (valid) password field
+ password = login_submission["password"]
+ assert isinstance(password, str)
+
canonical_user_id = await self._check_local_password(
- qualified_user_id, password # type: ignore
+ qualified_user_id, password
)
if canonical_user_id:
@@ -938,19 +1068,9 @@ class AuthHandler(BaseHandler):
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
- if hasattr(provider, "check_3pid_auth"):
- # This function is able to return a deferred that either
- # resolves None, meaning authentication failure, or upon
- # success, to a str (which is the user_id) or a tuple of
- # (user_id, callback_func), where callback_func should be run
- # after we've finished everything else
- result = await provider.check_3pid_auth(medium, address, password)
- if result:
- # Check if the return value is a str or a tuple
- if isinstance(result, str):
- # If it's a str, set callback function to None
- result = (result, None)
- return result
+ result = await provider.check_3pid_auth(medium, address, password)
+ if result:
+ return result
return None, None
@@ -1008,16 +1128,11 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
- if hasattr(provider, "on_logged_out"):
- # This might return an awaitable, if it does block the log out
- # until it completes.
- result = provider.on_logged_out(
- user_id=user_info.user_id,
- device_id=user_info.device_id,
- access_token=access_token,
- )
- if inspect.isawaitable(result):
- await result
+ await provider.on_logged_out(
+ user_id=user_info.user_id,
+ device_id=user_info.device_id,
+ access_token=access_token,
+ )
# delete pushers associated with this access token
if user_info.token_id is not None:
@@ -1046,11 +1161,10 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
- if hasattr(provider, "on_logged_out"):
- for token, token_id, device_id in tokens_and_devices:
- await provider.on_logged_out(
- user_id=user_id, device_id=device_id, access_token=token
- )
+ for token, token_id, device_id in tokens_and_devices:
+ await provider.on_logged_out(
+ user_id=user_id, device_id=device_id, access_token=token
+ )
# delete pushers associated with the access tokens
await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1374,3 +1488,127 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
+
+
+class PasswordProvider:
+ """Wrapper for a password auth provider module
+
+ This class abstracts out all of the backwards-compatibility hacks for
+ password providers, to provide a consistent interface.
+ """
+
+ @classmethod
+ def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
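+ """Instantiate the given provider module and wrap it in a PasswordProvider.
+
+ Any exception raised by the module's constructor is logged and re-raised.
+ """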
+ try:
+ pp = module(config=config, account_handler=module_api)
+ except Exception as e:
+ logger.error("Error while initializing %r: %s", module, e)
+ raise
+ return cls(pp, module_api)
+
+ def __init__(self, pp, module_api: ModuleApi):
+ self._pp = pp
+ self._module_api = module_api
+
+ self._supported_login_types = {}
+
+ # grandfather in check_password support
+ if hasattr(self._pp, "check_password"):
+ self._supported_login_types[LoginType.PASSWORD] = ("password",)
+
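+ # merge in any login types the provider declares via get_supported_login_types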
+ g = getattr(self._pp, "get_supported_login_types", None)
+ if g:
+ self._supported_login_types.update(g())
+
+ def __str__(self):
+ return str(self._pp)
+
+ def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
+ """Get the login types supported by this password provider
+
+ Returns a map from a login type identifier (such as m.login.password) to an
+ iterable giving the fields which must be provided by the user in the submission
+ to the /login API.
+
+ This wrapper adds m.login.password to the list if the underlying password
+ provider supports the check_password() api.
+ """
+ return self._supported_login_types
+
+ async def check_auth(
+ self, username: str, login_type: str, login_dict: JsonDict
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ """Check if the user has presented valid login credentials
+
+ This wrapper also calls check_password() if the underlying password provider
+ supports the check_password() api and the login type is m.login.password.
+
+ Args:
+ username: user id presented by the client. Either an MXID or an unqualified
+ username.
+
+ login_type: the login type being attempted - one of the types returned by
+ get_supported_login_types()
+
+ login_dict: the dictionary of login secrets passed by the client.
+
+ Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
+ user, and `callback` is an optional callback which will be called with the
+ result from the /login call (including access_token, device_id, etc.)
+ """
+ # first grandfather in a call to check_password
+ if login_type == LoginType.PASSWORD:
+ g = getattr(self._pp, "check_password", None)
+ if g:
+ qualified_user_id = self._module_api.get_qualified_user_id(username)
+ is_valid = await self._pp.check_password(
+ qualified_user_id, login_dict["password"]
+ )
+ if is_valid:
+ return qualified_user_id, None
+
+ g = getattr(self._pp, "check_auth", None)
+ if not g:
+ return None
+ result = await g(username, login_type, login_dict)
+
+ # Check if the return value is a str or a tuple
+ if isinstance(result, str):
+ # If it's a str, set callback function to None
+ return result, None
+
+ return result
+
+ async def check_3pid_auth(
+ self, medium: str, address: str, password: str
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ g = getattr(self._pp, "check_3pid_auth", None)
+ if not g:
+ return None
+
+ # This function is able to return a deferred that either
+ # resolves None, meaning authentication failure, or upon
+ # success, to a str (which is the user_id) or a tuple of
+ # (user_id, callback_func), where callback_func should be run
+ # after we've finished everything else
+ result = await g(medium, address, password)
+
+ # Check if the return value is a str or a tuple
+ if isinstance(result, str):
+ # If it's a str, set callback function to None
+ return result, None
+
+ return result
+
+ async def on_logged_out(
+ self, user_id: str, device_id: Optional[str], access_token: str
+ ) -> None:
+ g = getattr(self._pp, "on_logged_out", None)
+ if not g:
+ return
+
+ # This might return an awaitable; if it does, block the logout
+ # until it completes.
+ result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
+ if inspect.isawaitable(result):
+ await result
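For context, a minimal sketch of the sort of legacy provider module that `PasswordProvider` wraps (class and config key names here are illustrative; the `parse_config`/`check_password` interface is the one described in docs/password_auth_providers.md):

```python
from typing import Any, Dict


class ExamplePasswordProvider:
    """Illustrative legacy provider; PasswordProvider.load() wraps instances of this."""

    def __init__(self, config: Dict[str, Any], account_handler):
        self._config = config
        self._account_handler = account_handler

    @staticmethod
    def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
        # sanity-check the provider's config block; the return value is
        # passed back into __init__ as `config`
        return config

    async def check_password(self, user_id: str, password: str) -> bool:
        # the presence of check_password() is what causes the wrapper to
        # grandfather in m.login.password for this provider
        return password == self._config.get("shared_secret", "")
```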
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index bc3e9607ca..9b3c6b4551 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -354,7 +354,8 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "An error was encountered when sending the email")
token_expires = (
- self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
+ self.hs.get_clock().time_msec()
+ + self.hs.config.email_validation_token_lifetime
)
await self.store.start_or_continue_validation_session(
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 78c4e94a9d..55c4377890 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -39,7 +39,7 @@ from synapse.handlers._base import BaseHandler
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
-from synapse.types import JsonDict, map_username_to_mxid_localpart
+from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
@@ -898,13 +898,39 @@ class OidcHandler(BaseHandler):
return UserAttributes(**attributes)
+ async def grandfather_existing_users() -> Optional[str]:
+ if self._allow_existing_users:
+ # If allowing existing users we want to generate a single localpart
+ # and attempt to match it.
+ attributes = await oidc_response_to_user_attributes(failures=0)
+
+ user_id = UserID(attributes.localpart, self.server_name).to_string()
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
+ if users:
+ # If an existing matrix ID is returned, then use it.
+ if len(users) == 1:
+ previously_registered_user_id = next(iter(users))
+ elif user_id in users:
+ previously_registered_user_id = user_id
+ else:
+ # Do not attempt to continue generating Matrix IDs.
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, users
+ )
+ )
+
+ return previously_registered_user_id
+
+ return None
+
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
oidc_response_to_user_attributes,
- self._allow_existing_users,
+ grandfather_existing_users,
)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 426b58da9e..5372753707 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -299,17 +299,22 @@ class PaginationHandler:
"""
return self._purges_by_id.get(purge_id)
- async def purge_room(self, room_id: str) -> None:
- """Purge the given room from the database"""
+ async def purge_room(self, room_id: str, force: bool = False) -> None:
+ """Purge the given room from the database.
+
+ Args:
+ room_id: room to be purged
+ force: set true to skip checking for joined users.
+ """
with await self.pagination_lock.write(room_id):
# check we know about the room
await self.store.get_room_version_id(room_id)
# first check that we have no users in this room
- joined = await self.store.is_host_joined(room_id, self._server_name)
-
- if joined:
- raise SynapseError(400, "Users are still joined to this room")
+ if not force:
+ joined = await self.store.is_host_joined(room_id, self._server_name)
+ if joined:
+ raise SynapseError(400, "Users are still joined to this room")
await self.storage.purge_events.purge_room(room_id)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4e693a419e..4d8ffe8821 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -366,7 +366,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# later on.
content = dict(content)
- if not self.allow_per_room_profiles or requester.shadow_banned:
+ # allow the server notices mxid to set a room-level profile
+ is_requester_server_notices_user = (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ )
+
+ if (
+ not self.allow_per_room_profiles and not is_requester_server_notices_user
+ ) or requester.shadow_banned:
# Strip profile data, knowing that new profile data will be added to the
# event's content in event_creation_handler.create_event() using the target's
# global profile.
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 34db10ffe4..76d4169fe2 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -265,10 +265,10 @@ class SamlHandler(BaseHandler):
return UserAttributes(
localpart=result.get("mxid_localpart"),
display_name=result.get("displayname"),
- emails=result.get("emails"),
+ emails=result.get("emails", []),
)
- with (await self._mapping_lock.queue(self._auth_provider_id)):
+ async def grandfather_existing_users() -> Optional[str]:
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
@@ -290,17 +290,18 @@ class SamlHandler(BaseHandler):
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
- await self.store.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id
- )
return registered_user_id
+ return None
+
+ with (await self._mapping_lock.queue(self._auth_provider_id)):
return await self._sso_handler.get_mxid_from_sso(
self._auth_provider_id,
remote_user_id,
user_agent,
ip_address,
saml_response_to_remapped_user_attributes,
+ grandfather_existing_users,
)
def expire_sessions(self):
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index d963082210..f42b90e1bc 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -116,7 +116,7 @@ class SsoHandler(BaseHandler):
user_agent: str,
ip_address: str,
sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
- allow_existing_users: bool = False,
+ grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
) -> str:
"""
Given an SSO ID, retrieve the user ID for it and possibly register the user.
@@ -125,6 +125,10 @@ class SsoHandler(BaseHandler):
if it has that matrix ID is returned regardless of the current mapping
logic.
+ If a callable is provided for grandfathering users, it is called and can
+ potentially return a matrix ID to use. If it does, the SSO ID is linked to
+ this matrix ID for subsequent calls.
+
The mapping function is called (potentially multiple times) to generate
a localpart for the user.
@@ -132,17 +136,6 @@ class SsoHandler(BaseHandler):
given user-agent and IP address and the SSO ID is linked to this matrix
ID for subsequent calls.
- If allow_existing_users is true the mapping function is only called once
- and results in:
-
- 1. The use of a previously registered matrix ID. In this case, the
- SSO ID is linked to the matrix ID. (Note it is possible that
- other SSO IDs are linked to the same matrix ID.)
- 2. An unused localpart, in which case the user is registered (as
- discussed above).
- 3. An error if the generated localpart matches multiple pre-existing
- matrix IDs. Generally this should not happen.
-
Args:
auth_provider_id: A unique identifier for this SSO provider, e.g.
"oidc" or "saml".
@@ -152,8 +145,9 @@ class SsoHandler(BaseHandler):
sso_to_matrix_id_mapper: A callable to generate the user attributes.
The only parameter is an integer which represents the amount of
times the returned mxid localpart mapping has failed.
- allow_existing_users: True if the localpart returned from the
- mapping provider can be linked to an existing matrix ID.
+ grandfather_existing_users: A callable which can return a previously
+ existing matrix ID. The SSO ID is then linked to the returned
+ matrix ID.
Returns:
The user ID associated with the SSO response.
@@ -171,6 +165,16 @@ class SsoHandler(BaseHandler):
if previously_registered_user_id:
return previously_registered_user_id
+ # Check for grandfathering of users.
+ if grandfather_existing_users:
+ previously_registered_user_id = await grandfather_existing_users()
+ if previously_registered_user_id:
+ # Future logins should also match this user ID.
+ await self.store.record_user_external_id(
+ auth_provider_id, remote_user_id, previously_registered_user_id
+ )
+ return previously_registered_user_id
+
# Otherwise, generate a new user.
for i in range(self._MAP_USERNAME_RETRIES):
try:
@@ -194,33 +198,7 @@ class SsoHandler(BaseHandler):
# Check if this mxid already exists
user_id = UserID(attributes.localpart, self.server_name).to_string()
- users = await self.store.get_users_by_id_case_insensitive(user_id)
- # Note, if allow_existing_users is true then the loop is guaranteed
- # to end on the first iteration: either by matching an existing user,
- # raising an error, or registering a new user. See the docstring for
- # more in-depth an explanation.
- if users and allow_existing_users:
- # If an existing matrix ID is returned, then use it.
- if len(users) == 1:
- previously_registered_user_id = next(iter(users))
- elif user_id in users:
- previously_registered_user_id = user_id
- else:
- # Do not attempt to continue generating Matrix IDs.
- raise MappingException(
- "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
- user_id, users
- )
- )
-
- # Future logins should also match this user ID.
- await self.store.record_user_external_id(
- auth_provider_id, remote_user_id, previously_registered_user_id
- )
-
- return previously_registered_user_id
-
- elif not users:
+ if not await self.store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:
diff --git a/synapse/http/server.py b/synapse/http/server.py
index c0919f8cb7..6a4e429a6c 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -25,7 +25,7 @@ from io import BytesIO
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
import jinja2
-from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
+from canonicaljson import iterencode_canonical_json
from zope.interface import implementer
from twisted.internet import defer, interfaces
@@ -94,11 +94,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
pass
else:
respond_with_json(
- request,
- error_code,
- error_dict,
- send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
+ request, error_code, error_dict, send_cors=True,
)
@@ -290,7 +286,6 @@ class DirectServeJsonResource(_AsyncResource):
code,
response_object,
send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
)
@@ -587,7 +582,6 @@ def respond_with_json(
code: int,
json_object: Any,
send_cors: bool = False,
- pretty_print: bool = False,
canonical_json: bool = True,
):
"""Sends encoded JSON in response to the given request.
@@ -598,8 +592,6 @@ def respond_with_json(
json_object: The object to serialize to JSON.
send_cors: Whether to send Cross-Origin Resource Sharing headers
https://fetch.spec.whatwg.org/#http-cors-protocol
- pretty_print: Whether to include indentation and line-breaks in the
- resulting JSON bytes.
canonical_json: Whether to use the canonicaljson algorithm when encoding
the JSON bytes.
@@ -615,13 +607,10 @@ def respond_with_json(
)
return None
- if pretty_print:
- encoder = iterencode_pretty_printed_json
+ if canonical_json:
+ encoder = iterencode_canonical_json
else:
- if canonical_json:
- encoder = iterencode_canonical_json
- else:
- encoder = _encode_json_bytes
+ encoder = _encode_json_bytes
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
@@ -685,7 +674,7 @@ def set_cors_headers(request: Request):
)
request.setHeader(
b"Access-Control-Allow-Headers",
- b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
+ b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date",
)
@@ -759,11 +748,3 @@ def finish_request(request: Request):
request.finish()
except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e)
-
-
-def _request_user_agent_is_curl(request: Request) -> bool:
- user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
- for user_agent in user_agents:
- if b"curl" in user_agent:
- return True
- return False
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index afdf0bf131..d011e0aced 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -75,6 +75,7 @@ class HttpPusher:
self.failing_since = pusherdict["failing_since"]
self.timed_call = None
self._is_processing = False
+ self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
@@ -140,7 +141,11 @@ class HttpPusher:
async def _update_badge(self):
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it.
- badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
+ badge = await push_tools.get_badge_count(
+ self.hs.get_datastore(),
+ self.user_id,
+ group_by_room=self._group_unread_count_by_room,
+ )
await self._send_badge(badge)
def on_timer(self):
@@ -287,7 +292,11 @@ class HttpPusher:
return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
- badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
+ badge = await push_tools.get_badge_count(
+ self.hs.get_datastore(),
+ self.user_id,
+ group_by_room=self._group_unread_count_by_room,
+ )
event = await self.store.get_event(push_action["event_id"], allow_none=True)
if event is None:
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index d0145666bf..6e7c880dc0 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -12,12 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
+from synapse.storage.databases.main import DataStore
-async def get_badge_count(store, user_id):
+async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)
@@ -34,9 +34,15 @@ async def get_badge_count(store, user_id):
room_id, user_id, last_unread_event_id
)
)
- # return one badge count per conversation, as count per
- # message is so noisy as to be almost useless
- badge += 1 if notifs["notify_count"] else 0
+ if notifs["notify_count"] == 0:
+ continue
+
+ if group_by_room:
+ # return one badge count per conversation
+ badge += 1
+ else:
+ # increment the badge count by the number of unread messages in the room
+ badge += notifs["notify_count"]
return badge
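As a worked example of the loop above: suppose a user has unread notifications in two rooms, with `notify_count`s of 3 and 5.

```python
# Illustrative arithmetic only, mirroring the loop in get_badge_count().
notify_counts = [3, 5]

# group_by_room=True: one badge increment per room with unread messages
assert sum(1 for n in notify_counts if n > 0) == 2

# group_by_room=False: the badge is the total number of unread messages
assert sum(notify_counts) == 8
```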
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index f0c37eaf5e..84e002f934 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from twisted.web.http import Request
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -52,16 +53,23 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- async def _serialize_payload(
- requester, room_id, user_id, remote_room_hosts, content
- ):
+ async def _serialize_payload( # type: ignore
+ requester: Requester,
+ room_id: str,
+ user_id: str,
+ remote_room_hosts: List[str],
+ content: JsonDict,
+ ) -> JsonDict:
"""
Args:
- requester(Requester)
- room_id (str)
- user_id (str)
- remote_room_hosts (list[str]): Servers to try and join via
- content(dict): The event content to use for the join event
+ requester: The user making the request according to the access token
+ room_id: The ID of the room.
+ user_id: The ID of the user.
+ remote_room_hosts: Servers to try and join via
+ content: The event content to use for the join event
+
+ Returns:
+ A dict representing the payload of the request.
"""
return {
"requester": requester.serialize(),
@@ -69,7 +77,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"content": content,
}
- async def _handle_request(self, request, room_id, user_id):
+ async def _handle_request( # type: ignore
+ self, request: Request, room_id: str, user_id: str
+ ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@@ -118,14 +128,17 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
- ):
+ ) -> JsonDict:
"""
Args:
- invite_event_id: ID of the invite to be rejected
- txn_id: optional transaction ID supplied by the client
- requester: user making the rejection request, according to the access token
- content: additional content to include in the rejection event.
+ invite_event_id: The ID of the invite to be rejected.
+ txn_id: Optional transaction ID supplied by the client
+ requester: User making the rejection request, according to the access token
+ content: Additional content to include in the rejection event.
Normally an empty dict.
+
+ Returns:
+ A dict representing the payload of the request.
"""
return {
"txn_id": txn_id,
@@ -133,7 +146,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"content": content,
}
- async def _handle_request(self, request, invite_event_id):
+ async def _handle_request( # type: ignore
+ self, request: Request, invite_event_id: str
+ ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
txn_id = content["txn_id"]
@@ -174,18 +189,25 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
self.distributor = hs.get_distributor()
@staticmethod
- async def _serialize_payload(room_id, user_id, change):
+ async def _serialize_payload( # type: ignore
+ room_id: str, user_id: str, change: str
+ ) -> JsonDict:
"""
Args:
- room_id (str)
- user_id (str)
- change (str): "left"
+ room_id: The ID of the room.
+ user_id: The ID of the user.
+ change: "left"
+
+ Returns:
+ A dict representing the payload of the request.
"""
assert change == "left"
return {}
- def _handle_request(self, request, room_id, user_id, change):
+ def _handle_request( # type: ignore
+ self, request: Request, room_id: str, user_id: str, change: str
+ ) -> Tuple[int, JsonDict]:
logger.info("user membership change: %s in %s", user_id, room_id)
user = UserID.from_string(user_id)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 353151169a..25f89e4685 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -70,14 +70,18 @@ class ShutdownRoomRestServlet(RestServlet):
class DeleteRoomRestServlet(RestServlet):
- """Delete a room from server. It is a combination and improvement of
- shut down and purge room.
+ """Delete a room from server.
+
+ It is a combination and improvement of shutdown and purge room.
+
Shuts down a room by removing all local users from the room.
Blocking all future invites and joins to the room is optional.
+
If desired any local aliases will be repointed to a new room
- created by `new_room_user_id` and kicked users will be auto
+ created by `new_room_user_id` and kicked users will be auto-
joined to the new room.
- It will remove all trace of a room from the database.
+
+ If 'purge' is true, it will remove all traces of the room from the database.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
@@ -110,6 +114,14 @@ class DeleteRoomRestServlet(RestServlet):
Codes.BAD_JSON,
)
+ force_purge = content.get("force_purge", False)
+ if not isinstance(force_purge, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'force_purge' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
ret = await self.room_shutdown_handler.shutdown_room(
room_id=room_id,
new_room_user_id=content.get("new_room_user_id"),
@@ -121,7 +133,7 @@ class DeleteRoomRestServlet(RestServlet):
# Purge room
if purge:
- await self.pagination_handler.purge_room(room_id)
+ await self.pagination_handler.purge_room(room_id, force=force_purge)
return (200, ret)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 94452fcbf5..d7ae148214 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,10 +19,6 @@ from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
-from synapse.handlers.auth import (
- convert_client_dict_legacy_fields_to_identifier,
- login_id_phone_to_thirdparty,
-)
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@@ -33,7 +29,6 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
-from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
@@ -78,11 +73,6 @@ class LoginRestServlet(RestServlet):
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
)
- self._failed_attempts_ratelimiter = Ratelimiter(
- clock=hs.get_clock(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- )
def on_GET(self, request: SynapseRequest):
flows = []
@@ -140,27 +130,31 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
- def _get_qualified_user_id(self, identifier):
- if identifier["type"] != "m.id.user":
- raise SynapseError(400, "Unknown login identifier type")
- if "user" not in identifier:
- raise SynapseError(400, "User identifier is missing 'user' key")
-
- if identifier["user"].startswith("@"):
- return identifier["user"]
- else:
- return UserID(identifier["user"], self.hs.hostname).to_string()
-
async def _do_appservice_login(
self, login_submission: JsonDict, appservice: ApplicationService
):
- logger.info(
- "Got appservice login request with identifier: %r",
- login_submission.get("identifier"),
- )
+ identifier = login_submission.get("identifier")
+ logger.info("Got appservice login request with identifier: %r", identifier)
+
+ if not isinstance(identifier, dict):
+ raise SynapseError(
+ 400, "Invalid identifier in login submission", Codes.INVALID_PARAM
+ )
+
+ # this login flow only supports identifiers of type "m.id.user".
+ if identifier.get("type") != "m.id.user":
+ raise SynapseError(
+ 400, "Unknown login identifier type", Codes.INVALID_PARAM
+ )
- identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
- qualified_user_id = self._get_qualified_user_id(identifier)
+ user = identifier.get("user")
+ if not isinstance(user, str):
+ raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM)
+
+ if user.startswith("@"):
+ qualified_user_id = user
+ else:
+ qualified_user_id = UserID(user, self.hs.hostname).to_string()
if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
@@ -186,91 +180,9 @@ class LoginRestServlet(RestServlet):
login_submission.get("address"),
login_submission.get("user"),
)
- identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
-
- # convert phone type identifiers to generic threepids
- if identifier["type"] == "m.id.phone":
- identifier = login_id_phone_to_thirdparty(identifier)
-
- # convert threepid identifiers to user IDs
- if identifier["type"] == "m.id.thirdparty":
- address = identifier.get("address")
- medium = identifier.get("medium")
-
- if medium is None or address is None:
- raise SynapseError(400, "Invalid thirdparty identifier")
-
- # For emails, canonicalise the address.
- # We store all email addresses canonicalised in the DB.
- # (See add_threepid in synapse/handlers/auth.py)
- if medium == "email":
- try:
- address = canonicalise_email(address)
- except ValueError as e:
- raise SynapseError(400, str(e))
-
- # We also apply account rate limiting using the 3PID as a key, as
- # otherwise using 3PID bypasses the ratelimiting based on user ID.
- self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
-
- # Check for login providers that support 3pid login types
- (
- canonical_user_id,
- callback_3pid,
- ) = await self.auth_handler.check_password_provider_3pid(
- medium, address, login_submission["password"]
- )
- if canonical_user_id:
- # Authentication through password provider and 3pid succeeded
-
- result = await self._complete_login(
- canonical_user_id, login_submission, callback_3pid
- )
- return result
-
- # No password providers were able to handle this 3pid
- # Check local store
- user_id = await self.hs.get_datastore().get_user_id_by_threepid(
- medium, address
- )
- if not user_id:
- logger.warning(
- "unknown 3pid identifier medium %s, address %r", medium, address
- )
- # We mark that we've failed to log in here, as
- # `check_password_provider_3pid` might have returned `None` due
- # to an incorrect password, rather than the account not
- # existing.
- #
- # If it returned None but the 3PID was bound then we won't hit
- # this code path, which is fine as then the per-user ratelimit
- # will kick in below.
- self._failed_attempts_ratelimiter.can_do_action((medium, address))
- raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
- identifier = {"type": "m.id.user", "user": user_id}
-
- # by this point, the identifier should be an m.id.user: if it's anything
- # else, we haven't understood it.
- qualified_user_id = self._get_qualified_user_id(identifier)
-
- # Check if we've hit the failed ratelimit (but don't update it)
- self._failed_attempts_ratelimiter.ratelimit(
- qualified_user_id.lower(), update=False
+ canonical_user_id, callback = await self.auth_handler.validate_login(
+ login_submission, ratelimit=True
)
-
- try:
- canonical_user_id, callback = await self.auth_handler.validate_login(
- identifier["user"], login_submission
- )
- except LoginError:
- # The user has failed to log in, so we need to update the rate
- # limiter. Using `can_do_action` avoids us raising a ratelimit
- # exception and masking the LoginError. The actual ratelimiting
- # should have happened above.
- self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
- raise
-
result = await self._complete_login(
canonical_user_id, login_submission, callback
)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index a54e1011f7..eebee44a44 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -115,7 +115,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@@ -387,7 +387,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -466,7 +466,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index ea68114026..a89ae6ddf9 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -135,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -214,7 +214,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index c16280f668..d8e8e48c1c 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -66,7 +66,7 @@ class LocalKey(Resource):
def __init__(self, hs):
self.config = hs.config
- self.clock = hs.clock
+ self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)
diff --git a/synapse/server.py b/synapse/server.py
index c82d8f9fad..b017e3489f 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -147,7 +147,8 @@ def cache_in_self(builder: T) -> T:
"@cache_in_self can only be used on functions starting with `get_`"
)
- depname = builder.__name__[len("get_") :]
+ # get_attr -> _attr
+ depname = builder.__name__[len("get") :]
building = [False]
@@ -235,15 +236,6 @@ class HomeServer(metaclass=abc.ABCMeta):
self._instance_id = random_string(5)
self._instance_name = config.worker_name or "master"
- self.clock = Clock(reactor)
- self.distributor = Distributor()
-
- self.registration_ratelimiter = Ratelimiter(
- clock=self.clock,
- rate_hz=config.rc_registration.per_second,
- burst_count=config.rc_registration.burst_count,
- )
-
self.version_string = version_string
self.datastores = None # type: Optional[Databases]
@@ -301,8 +293,9 @@ class HomeServer(metaclass=abc.ABCMeta):
def is_mine_id(self, string: str) -> bool:
return string.split(":", 1)[1] == self.hostname
+ @cache_in_self
def get_clock(self) -> Clock:
- return self.clock
+ return Clock(self._reactor)
def get_datastore(self) -> DataStore:
if not self.datastores:
@@ -319,11 +312,17 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_config(self) -> HomeServerConfig:
return self.config
+ @cache_in_self
def get_distributor(self) -> Distributor:
- return self.distributor
+ return Distributor()
+ @cache_in_self
def get_registration_ratelimiter(self) -> Ratelimiter:
- return self.registration_ratelimiter
+ return Ratelimiter(
+ clock=self.get_clock(),
+ rate_hz=self.config.rc_registration.per_second,
+ burst_count=self.config.rc_registration.burst_count,
+ )
@cache_in_self
def get_federation_client(self) -> FederationClient:
@@ -687,7 +686,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_federation_ratelimiter(self) -> FederationRateLimiter:
- return FederationRateLimiter(self.clock, config=self.config.rc_federation)
+ return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation)
@cache_in_self
def get_module_api(self) -> ModuleApi:
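The updated comment documents the key trick: slicing `builder.__name__[len("get"):]` turns a `get_foo` getter into a `_foo` cache attribute, which is also why `tests/utils.py` at the end of this patch injects mocks under `"_" + key`. A minimal sketch of such a decorator, assuming the real one adds cycle detection via the `building` flag visible in the hunk:

```python
import functools


def cache_in_self(builder):
    if not builder.__name__.startswith("get_"):
        raise Exception(
            "@cache_in_self can only be used on functions starting with `get_`"
        )

    # get_attr -> _attr
    depname = builder.__name__[len("get"):]

    @functools.wraps(builder)
    def _get(self):
        # a pre-seeded `_attr` (e.g. a mock injected by a test) wins
        dep = getattr(self, depname, None)
        if dep is None:
            dep = builder(self)
            setattr(self, depname, dep)
        return dep

    return _get
```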
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ecfc6717b3..5d668aadb2 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
for table in (
"event_auth",
"event_edges",
+ "event_json",
"event_push_actions_staging",
"event_reference_hashes",
"event_relations",
@@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"destination_rooms",
"event_backward_extremities",
"event_forward_extremities",
- "event_json",
"event_push_actions",
"event_search",
"events",
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
index b64926e9c9..3275ae2b20 100644
--- a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
@@ -20,14 +20,14 @@
*/
-- add new index that includes method to local media
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('local_media_repository_thumbnails_method_idx', '{}');
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5807, 'local_media_repository_thumbnails_method_idx', '{}');
-- add new index that includes method to remote media
-INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
- ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (5807, 'remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
-- drop old index
-INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
- ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (5807, 'media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
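The `ordering` values in this and the following deltas appear to encode the delta filename: schema version 58, file 07 gives 5807. A hypothetical helper capturing that convention:

```python
def background_update_ordering(schema_version: int, delta_number: int) -> int:
    # e.g. delta file 58/07... -> ordering 5807
    return schema_version * 100 + delta_number


assert background_update_ordering(58, 7) == 5807
assert background_update_ordering(58, 12) == 5812
assert background_update_ordering(58, 23) == 5823
```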
diff --git a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
index cade5dcca8..fd733adf13 100644
--- a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
+++ b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
@@ -28,5 +28,5 @@
-- functionality as the old one. This effectively restarts the background job
-- from the beginning, without running it twice in a row, supporting both
-- upgrade use cases.
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('populate_stats_process_rooms_2', '{}');
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5812, 'populate_stats_process_rooms_2', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
index a2842687f1..e1a35be831 100644
--- a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
+++ b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
@@ -1,2 +1,2 @@
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('users_have_local_media', '{}');
\ No newline at end of file
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5822, 'users_have_local_media', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
index 61c558db77..75c3915a94 100644
--- a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
@@ -13,5 +13,5 @@
* limitations under the License.
*/
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('e2e_cross_signing_keys_idx', '{}');
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5823, 'e2e_cross_signing_keys_idx', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql
new file mode 100644
index 0000000000..8a39d54aed
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql
@@ -0,0 +1,19 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this index is essentially redundant. The only time it was ever used was when purging
+-- rooms - and Synapse 1.24 will change that.
+
+DROP INDEX IF EXISTS event_json_room_id;
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index b5055e018c..e24ce81284 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -52,7 +52,7 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self):
- self.hs.clock.now = 5000
+ self.hs.get_clock().now = 5000
token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -78,7 +78,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self):
- self.hs.clock.now = 1000
+ self.hs.get_clock().now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
user_id = yield defer.ensureDeferred(
@@ -87,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
- self.hs.clock.now = 6000
+ self.hs.get_clock().now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e880d32be6..d485af52fd 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -23,7 +23,7 @@ import pymacaroons
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
-from synapse.handlers.oidc_handler import OidcError, OidcHandler, OidcMappingProvider
+from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
from synapse.types import UserID
@@ -127,13 +127,8 @@ async def get_json(url):
class OidcHandlerTestCase(HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
-
- self.http_client = Mock(spec=["get_json"])
- self.http_client.get_json.side_effect = get_json
- self.http_client.user_agent = "Synapse Test"
-
- config = self.default_config()
+ def default_config(self):
+ config = super().default_config()
config["public_baseurl"] = BASE_URL
oidc_config = {
"enabled": True,
@@ -149,19 +144,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
- hs = self.setup_test_homeserver(
- http_client=self.http_client,
- proxied_http_client=self.http_client,
- config=config,
- )
+ return config
+
+ def make_homeserver(self, reactor, clock):
- self.handler = OidcHandler(hs)
+ self.http_client = Mock(spec=["get_json"])
+ self.http_client.get_json.side_effect = get_json
+ self.http_client.user_agent = "Synapse Test"
+
+ hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+
+ self.handler = hs.get_oidc_handler()
+ sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
- self.handler._sso_handler.render_error = self.render_error
+ sso_handler.render_error = self.render_error
# Reduce the number of attempts when generating MXIDs.
- self.handler._sso_handler._MAP_USERNAME_RETRIES = 3
+ sso_handler._MAP_USERNAME_RETRIES = 3
return hs
@@ -731,6 +731,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
self.assertEqual(mxid, "@test_user:test")
+ # Subsequent calls should map to the same mxid.
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
# in this case we'll just use the same username. A more realistic example
@@ -832,7 +840,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
self.assertEqual(mxid, "@test_user1:test")
- # Register all of the potential users for a particular username.
+ # Register all of the potential mxids for a particular OIDC username.
self.get_success(
store.register_user(user_id="@tester:test", password_hash=None)
)
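Moving the OIDC settings from `make_homeserver` into `default_config` is what lets `@override_config` on individual tests merge cleanly. One way such a decorator can be wired up (an assumption about the harness, not necessarily the exact implementation in `tests.unittest`):

```python
def override_config(extra_config: dict):
    """Mark a test method with config to merge over default_config()."""

    def decorator(test_method):
        # the harness is assumed to read this attribute during setup and
        # deep-merge it into the result of default_config()
        test_method._extra_config = extra_config
        return test_method

    return decorator
```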
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
new file mode 100644
index 0000000000..ceaf0902d2
--- /dev/null
+++ b/tests/handlers/test_password_providers.py
@@ -0,0 +1,580 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the password_auth_provider interface"""
+
+from typing import Any, Type, Union
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import devices
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.server import FakeChannel
+from tests.unittest import override_config
+
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
+# a mock instance which the dummy auth providers delegate to, so we can see what's going
+# on
+mock_password_provider = Mock()
+
+
+class PasswordOnlyAuthProvider:
+ """A password_provider which only implements `check_password`."""
+
+ @staticmethod
+    def parse_config(config):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def check_password(self, *args):
+ return mock_password_provider.check_password(*args)
+
+
+class CustomAuthProvider:
+ """A password_provider which implements a custom login type."""
+
+ @staticmethod
+    def parse_config(config):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def get_supported_login_types(self):
+ return {"test.login_type": ["test_field"]}
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+
+class PasswordCustomAuthProvider:
+ """A password_provider which implements password login via `check_auth`, as well
+    as a custom login type."""
+
+ @staticmethod
+    def parse_config(config):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def get_supported_login_types(self):
+ return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+
+def providers_config(*providers: Type[Any]) -> dict:
+ """Returns a config dict that will enable the given password auth providers"""
+ return {
+ "password_providers": [
+ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
+ for provider in providers
+ ]
+ }
+
+
+class PasswordAuthProviderTests(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def setUp(self):
+        # we use a global mock password provider, so make sure we are starting with a clean slate
+ mock_password_provider.reset_mock()
+ super().setUp()
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_password_only_auth_provider_login(self):
+ # login flows should only have m.login.password
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@u:test", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+ mock_password_provider.reset_mock()
+
+ # login with mxid should work too
+ channel = self._send_password_login("@u:bz", "p")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@u:bz", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
+ mock_password_provider.reset_mock()
+
+ # try a weird username / pass. Honestly it's unclear what we *expect* to happen
+ # in these cases, but at least we can guard against the API changing
+ # unexpectedly
+ channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with(
+ "@ USER🙂NAME :test", " pASS😢word "
+ )
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_password_only_auth_provider_ui_auth(self):
+ """UI Auth should delegate correctly to the password provider"""
+
+ # create the user, otherwise access doesn't work
+ module_api = self.hs.get_module_api()
+ self.get_success(module_api.register_user("u"))
+
+ # log in twice, to get two devices
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ tok1 = self.login("u", "p")
+ self.login("u", "p", device_id="dev2")
+ mock_password_provider.reset_mock()
+
+ # have the auth provider deny the request to start with
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+
+ # make the initial request which returns a 401
+ session = self._start_delete_device_session(tok1, "dev2")
+ mock_password_provider.check_password.assert_not_called()
+
+ # Make another request providing the UI auth flow.
+ channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+ mock_password_provider.reset_mock()
+
+ # Finally, check the request goes through when we allow it
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_local_user_fallback_login(self):
+ """rejected login should fall back to local db"""
+ self.register_user("localuser", "localpass")
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 403, channel.result)
+
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@localuser:test", channel.json_body["user_id"])
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_local_user_fallback_ui_auth(self):
+ """rejected login should fall back to local db"""
+ self.register_user("localuser", "localpass")
+
+ # have the auth provider deny the request
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+
+ # log in twice, to get two devices
+ tok1 = self.login("localuser", "localpass")
+ self.login("localuser", "localpass", device_id="dev2")
+ mock_password_provider.check_password.reset_mock()
+
+ # first delete should give a 401
+ session = self._start_delete_device_session(tok1, "dev2")
+ mock_password_provider.check_password.assert_not_called()
+
+ # Wrong password
+ channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx")
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "xxx"
+ )
+ mock_password_provider.reset_mock()
+
+ # Right password
+ channel = self._authed_delete_device(
+ tok1, "dev2", session, "localuser", "localpass"
+ )
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_no_local_user_fallback_login(self):
+ """localdb_enabled can block login with the local password
+ """
+ self.register_user("localuser", "localpass")
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_no_local_user_fallback_ui_auth(self):
+ """localdb_enabled can block ui auth with the local password
+ """
+ self.register_user("localuser", "localpass")
+
+ # allow login via the auth provider
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+
+ # log in twice, to get two devices
+ tok1 = self.login("localuser", "p")
+ self.login("localuser", "p", device_id="dev2")
+ mock_password_provider.check_password.reset_mock()
+
+ # first delete should give a 401
+ channel = self._delete_device(tok1, "dev2")
+ self.assertEqual(channel.code, 401)
+ # m.login.password UIA is permitted because the auth provider allows it,
+ # even though the localdb does not.
+ self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
+ session = channel.json_body["session"]
+ mock_password_provider.check_password.assert_not_called()
+
+ # now try deleting with the local password
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._authed_delete_device(
+ tok1, "dev2", session, "localuser", "localpass"
+ )
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_auth_disabled(self):
+ """password auth doesn't work if it's disabled across the board"""
+ # login flows should be empty
+ flows = self._get_login_flows()
+ self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_password.assert_not_called()
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_login(self):
+ # login flows should have the custom flow and m.login.password, since we
+ # haven't disabled local password lookup.
+        # (the password flow must come first; the assertion below relies on the ordering)
+ flows = self._get_login_flows()
+ self.assertEqual(
+ flows,
+ [{"type": "m.login.password"}, {"type": "test.login_type"}]
+ + ADDITIONAL_LOGIN_FLOWS,
+ )
+
+ # login with missing param should be rejected
+ channel = self._send_login("test.login_type", "u")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ channel = self._send_login("test.login_type", "u", test_field="y")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@user:bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ "u", "test.login_type", {"test_field": "y"}
+ )
+ mock_password_provider.reset_mock()
+
+ # try a weird username. Again, it's unclear what we *expect* to happen
+ # in these cases, but at least we can guard against the API changing
+ # unexpectedly
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@ MALFORMED! :bz"
+ )
+ channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
+ )
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_ui_auth(self):
+ # register the user and log in twice, to get two devices
+ self.register_user("localuser", "localpass")
+ tok1 = self.login("localuser", "localpass")
+ self.login("localuser", "localpass", device_id="dev2")
+
+ # make the initial request which returns a 401
+ channel = self._delete_device(tok1, "dev2")
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+ self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
+ session = channel.json_body["session"]
+
+ # missing param
+ body = {
+ "auth": {
+ "type": "test.login_type",
+ "identifier": {"type": "m.id.user", "user": "localuser"},
+ "session": session,
+ },
+ }
+
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 400)
+ # there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should
+ # use it...
+ self.assertIn("Missing parameters", channel.json_body["error"])
+ mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.reset_mock()
+
+ # right params, but authing as the wrong user
+ mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ body["auth"]["test_field"] = "foo"
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "foo"}
+ )
+ mock_password_provider.reset_mock()
+
+ # and finally, succeed
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@localuser:test"
+ )
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "foo"}
+ )
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_callback(self):
+ callback = Mock(return_value=defer.succeed(None))
+
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ ("@user:bz", callback)
+ )
+ channel = self._send_login("test.login_type", "u", test_field="y")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@user:bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ "u", "test.login_type", {"test_field": "y"}
+ )
+
+ # check the args to the callback
+ callback.assert_called_once()
+ call_args, call_kwargs = callback.call_args
+ # should be one positional arg
+ self.assertEqual(len(call_args), 1)
+ self.assertEqual(call_args[0]["user_id"], "@user:bz")
+ for p in ["user_id", "access_token", "device_id", "home_server"]:
+ self.assertIn(p, call_args[0])
+
+ @override_config(
+ {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
+ )
+ def test_custom_auth_password_disabled(self):
+ """Test login with a custom auth provider where password login is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
+ **providers_config(PasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_login(self):
+ """log in with a custom auth provider which implements password, but password
+ login is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
+ **providers_config(PasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_ui_auth(self):
+ """UI Auth with a custom auth provider which implements password, but password
+ login is disabled"""
+        # register the user and log in twice via the test login type to get two devices
+ self.register_user("localuser", "localpass")
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@localuser:test"
+ )
+ channel = self._send_login("test.login_type", "localuser", test_field="")
+ self.assertEqual(channel.code, 200, channel.result)
+ tok1 = channel.json_body["access_token"]
+
+ channel = self._send_login(
+ "test.login_type", "localuser", test_field="", device_id="dev2"
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # make the initial request which returns a 401
+ channel = self._delete_device(tok1, "dev2")
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected. In particular, "password" should *not*
+ # be present.
+ self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
+ session = channel.json_body["session"]
+
+ mock_password_provider.reset_mock()
+
+ # check that auth with password is rejected
+ body = {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": "localuser"},
+ "password": "localpass",
+ "session": session,
+ },
+ }
+
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 400)
+ self.assertEqual(
+ "Password login has been disabled.", channel.json_body["error"]
+ )
+ mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.reset_mock()
+
+ # successful auth
+ body["auth"]["type"] = "test.login_type"
+ body["auth"]["test_field"] = "x"
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "x"}
+ )
+
+ @override_config(
+ {
+ **providers_config(CustomAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_no_local_user_fallback(self):
+ """Test login with a custom auth provider where the local db is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # password login shouldn't work and should be rejected with a 400
+ # ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+
+    def _get_login_flows(self) -> list:
+ _, channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["flows"]
+
+ def _send_password_login(self, user: str, password: str) -> FakeChannel:
+ return self._send_login(type="m.login.password", user=user, password=password)
+
+ def _send_login(self, type, user, **params) -> FakeChannel:
+ params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
+ _, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
+ return channel
+
+ def _start_delete_device_session(self, access_token, device_id) -> str:
+ """Make an initial delete device request, and return the UI Auth session ID"""
+ channel = self._delete_device(access_token, device_id)
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+ return channel.json_body["session"]
+
+ def _authed_delete_device(
+ self,
+ access_token: str,
+ device_id: str,
+ session: str,
+ user_id: str,
+ password: str,
+ ) -> FakeChannel:
+ """Make a delete device request, authenticating with the given uid/password"""
+ return self._delete_device(
+ access_token,
+ device_id,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": user_id},
+ "password": password,
+ "session": session,
+ },
+ },
+ )
+
+ def _delete_device(
+ self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
+ ) -> FakeChannel:
+ """Delete an individual device."""
+ _, channel = self.make_request(
+ "DELETE", "devices/" + device, body, access_token=access_token
+ )
+ return channel
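For orientation, the ordering these tests pin down is: configured providers are consulted first, then the local password database, and `password_config.localdb_enabled: false` removes that fallback. A hedged sketch of the flow, not Synapse's actual code:

```python
async def check_login(
    user_id: str, password: str, providers, localdb, localdb_enabled: bool = True
) -> bool:
    # providers get first refusal...
    for provider in providers:
        if await provider.check_password(user_id, password):
            return True
    # ...then the local database, unless the fallback is disabled
    if localdb_enabled:
        return await localdb.check_password(user_id, password)
    return False
```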
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
new file mode 100644
index 0000000000..e1e13a5faf
--- /dev/null
+++ b/tests/handlers/test_saml.py
@@ -0,0 +1,168 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import attr
+
+from synapse.handlers.sso import MappingException
+
+from tests.unittest import HomeserverTestCase, override_config
+
+# This constant is used as a config parameter in the tests.
+BASE_URL = "https://synapse/"
+
+
+@attr.s
+class FakeAuthnResponse:
+ ava = attr.ib(type=dict)
+
+
+class TestMappingProvider:
+ def __init__(self, config, module):
+ pass
+
+ @staticmethod
+ def parse_config(config):
+ return
+
+ @staticmethod
+ def get_saml_attributes(config):
+ return {"uid"}, {"displayName"}
+
+ def get_remote_user_id(self, saml_response, client_redirect_url):
+ return saml_response.ava["uid"]
+
+ def saml_response_to_user_attributes(
+ self, saml_response, failures, client_redirect_url
+ ):
+ localpart = saml_response.ava["username"] + (str(failures) if failures else "")
+ return {"mxid_localpart": localpart, "displayname": None}
+
+
+class SamlHandlerTestCase(HomeserverTestCase):
+ def default_config(self):
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+ saml_config = {
+ "sp_config": {"metadata": {}},
+ # Disable grandfathering.
+ "grandfathered_mxid_source_attribute": None,
+ "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
+ }
+
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ saml_config.update(config.get("saml2_config", {}))
+ config["saml2_config"] = saml_config
+
+ return config
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+
+ self.handler = hs.get_saml_handler()
+
+ # Reduce the number of attempts when generating MXIDs.
+ sso_handler = hs.get_sso_handler()
+ sso_handler._MAP_USERNAME_RETRIES = 3
+
+ return hs
+
+ def test_map_saml_response_to_user(self):
+ """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
+ saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
+ # The redirect_url doesn't matter with the default user mapping provider.
+ redirect_url = ""
+ mxid = self.get_success(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
+ def test_map_saml_response_to_existing_user(self):
+ """Existing users can log in with SAML account."""
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+
+ # Map a user via SSO.
+ saml_response = FakeAuthnResponse(
+ {"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
+ )
+ redirect_url = ""
+ mxid = self.get_success(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ # Subsequent calls should map to the same mxid.
+ mxid = self.get_success(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ def test_map_saml_response_to_invalid_localpart(self):
+ """If the mapping provider generates an invalid localpart it should be rejected."""
+ saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
+ redirect_url = ""
+ e = self.get_failure(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(str(e.value), "localpart is invalid: föö")
+
+ def test_map_saml_response_to_user_retries(self):
+ """The mapping provider can retry generating an MXID if the MXID is already in use."""
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+ saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
+ redirect_url = ""
+ mxid = self.get_success(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ )
+ )
+ # test_user is already taken, so test_user1 gets registered instead.
+ self.assertEqual(mxid, "@test_user1:test")
+
+ # Register all of the potential mxids for a particular SAML username.
+ self.get_success(
+ store.register_user(user_id="@tester:test", password_hash=None)
+ )
+ for i in range(1, 3):
+ self.get_success(
+ store.register_user(user_id="@tester%d:test" % i, password_hash=None)
+ )
+
+ # Now attempt to map to a username, this will fail since all potential usernames are taken.
+ saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
+ e = self.get_failure(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(
+ str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ )
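The retry tests rely on the mapping provider being re-invoked with an incremented failure count, which `TestMappingProvider` appends to the localpart. Roughly (the exact bound and helper names here are assumptions):

```python
async def map_to_mxid(base_localpart: str, is_taken, retries: int = 3) -> str:
    # tries tester, tester1, tester2, ... until a free localpart is found
    # or the retries are exhausted
    for failures in range(retries + 1):
        candidate = base_localpart + (str(failures) if failures else "")
        if not await is_taken(candidate):
            return "@%s:test" % (candidate,)
    raise Exception("Unable to generate a Matrix ID from the SSO response")
```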
diff --git a/tests/push/test_http.py b/tests/push/test_http.py
index 8571924b29..f118430309 100644
--- a/tests/push/test_http.py
+++ b/tests/push/test_http.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from mock import Mock
from twisted.internet.defer import Deferred
@@ -20,8 +19,9 @@ from twisted.internet.defer import Deferred
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
from synapse.rest.client.v1 import login, room
+from synapse.rest.client.v2_alpha import receipts
-from tests.unittest import HomeserverTestCase
+from tests.unittest import HomeserverTestCase, override_config
class HTTPPusherTests(HomeserverTestCase):
@@ -29,6 +29,7 @@ class HTTPPusherTests(HomeserverTestCase):
synapse.rest.admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
+ receipts.register_servlets,
]
user_id = True
hijack_auth = False
@@ -499,3 +500,161 @@ class HTTPPusherTests(HomeserverTestCase):
# check that this is low-priority
self.assertEqual(self.push_attempts[1][2]["notification"]["prio"], "low")
+
+ def test_push_unread_count_group_by_room(self):
+ """
+        The HTTP pusher reports the unread count as the number of unread rooms.
+ """
+ # Carry out common push count tests and setup
+ self._test_push_unread_count()
+
+ # Carry out our option-value specific test
+ #
+ # This push should still only contain an unread count of 1 (for 1 unread room)
+ self.assertEqual(
+ self.push_attempts[5][2]["notification"]["counts"]["unread"], 1
+ )
+
+ @override_config({"push": {"group_unread_count_by_room": False}})
+ def test_push_unread_count_message_count(self):
+ """
+ The HTTP pusher will send the total unread message count.
+ """
+ # Carry out common push count tests and setup
+ self._test_push_unread_count()
+
+ # Carry out our option-value specific test
+ #
+ # We're counting every unread message, so there should now be 4 since the
+ # last read receipt
+ self.assertEqual(
+ self.push_attempts[5][2]["notification"]["counts"]["unread"], 4
+ )
+
+ def _test_push_unread_count(self):
+ """
+ Tests that the correct unread count appears in sent push notifications
+
+ Note that:
+ * Sending messages will cause push notifications to go out to relevant users
+ * Sending a read receipt will cause a "badge update" notification to go out to
+ the user that sent the receipt
+ """
+ # Register the user who gets notified
+ user_id = self.register_user("user", "pass")
+ access_token = self.login("user", "pass")
+
+ # Register the user who sends the message
+ other_user_id = self.register_user("other_user", "pass")
+ other_access_token = self.login("other_user", "pass")
+
+ # Create a room (as other_user)
+ room_id = self.helper.create_room_as(other_user_id, tok=other_access_token)
+
+ # The user to get notified joins
+ self.helper.join(room=room_id, user=user_id, tok=access_token)
+
+ # Register the pusher
+ user_tuple = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_tuple.token_id
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "example.com"},
+ )
+ )
+
+ # Send a message
+ response = self.helper.send(
+ room_id, body="Hello there!", tok=other_access_token
+ )
+ # To get an unread count, the user who is getting notified has to have a read
+ # position in the room. We'll set the read position to this event in a moment
+ first_message_event_id = response["event_id"]
+
+ # Advance time a bit (so the pusher will register something has happened) and
+ # make the push succeed
+ self.push_attempts[0][0].callback({})
+ self.pump()
+
+ # Check our push made it
+ self.assertEqual(len(self.push_attempts), 1)
+ self.assertEqual(self.push_attempts[0][1], "example.com")
+
+ # Check that the unread count for the room is 0
+ #
+ # The unread count is zero as the user has no read receipt in the room yet
+ self.assertEqual(
+ self.push_attempts[0][2]["notification"]["counts"]["unread"], 0
+ )
+
+ # Now set the user's read receipt position to the first event
+ #
+ # This will actually trigger a new notification to be sent out so that
+ # even if the user does not receive another message, their unread
+ # count goes down
+ request, channel = self.make_request(
+ "POST",
+ "/rooms/%s/receipt/m.read/%s" % (room_id, first_message_event_id),
+ {},
+ access_token=access_token,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Advance time and make the push succeed
+ self.push_attempts[1][0].callback({})
+ self.pump()
+
+ # Unread count is still zero as we've read the only message in the room
+ self.assertEqual(len(self.push_attempts), 2)
+ self.assertEqual(
+ self.push_attempts[1][2]["notification"]["counts"]["unread"], 0
+ )
+
+ # Send another message
+ self.helper.send(
+ room_id, body="How's the weather today?", tok=other_access_token
+ )
+
+ # Advance time and make the push succeed
+ self.push_attempts[2][0].callback({})
+ self.pump()
+
+ # This push should contain an unread count of 1 as there's now been one
+ # message since our last read receipt
+ self.assertEqual(len(self.push_attempts), 3)
+ self.assertEqual(
+ self.push_attempts[2][2]["notification"]["counts"]["unread"], 1
+ )
+
+        # Send three more messages in the same room. When grouping by room this
+        # leaves the unread count at 1; otherwise it rises to 4, as asserted by
+        # the two callers above
+ self.helper.send(room_id, body="Hello?", tok=other_access_token)
+
+ # Advance time and make the push succeed
+ self.pump()
+ self.push_attempts[3][0].callback({})
+
+ self.helper.send(room_id, body="Hello??", tok=other_access_token)
+
+ # Advance time and make the push succeed
+ self.pump()
+ self.push_attempts[4][0].callback({})
+
+ self.helper.send(room_id, body="HELLO???", tok=other_access_token)
+
+ # Advance time and make the push succeed
+ self.pump()
+ self.push_attempts[5][0].callback({})
+
+ self.assertEqual(len(self.push_attempts), 6)
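The two option-specific tests above boil down to how the badge is computed from per-room unread counts. Illustratively:

```python
def unread_badge(unread_by_room: dict, group_by_room: bool) -> int:
    # group_by_room=True: "number of rooms with unread messages"
    # group_by_room=False: "total unread messages"
    if group_by_room:
        return sum(1 for count in unread_by_room.values() if count > 0)
    return sum(unread_by_room.values())


# matches the assertions above: one room with four unread messages
assert unread_badge({"!room:test": 4}, group_by_room=True) == 1
assert unread_badge({"!room:test": 4}, group_by_room=False) == 4
```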
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 516db4c30a..295c5d58a6 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -78,7 +78,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
self.test_handler = self._build_replication_data_handler()
- self.worker_hs.replication_data_handler = self.test_handler
+ self.worker_hs._replication_data_handler = self.test_handler
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(
diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py
index 2a65ab33bd..dadf9db660 100644
--- a/tests/rest/admin/test_media.py
+++ b/tests/rest/admin/test_media.py
@@ -192,7 +192,6 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase):
self.handler = hs.get_device_handler()
self.media_repo = hs.get_media_repository_resource()
self.server_name = hs.hostname
- self.clock = hs.clock
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py
index b84f86d28c..5d5c24d01c 100644
--- a/tests/rest/client/v1/test_presence.py
+++ b/tests/rest/client/v1/test_presence.py
@@ -33,13 +33,16 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
+ presence_handler = Mock()
+ presence_handler.set_state.return_value = defer.succeed(None)
+
hs = self.setup_test_homeserver(
- "red", http_client=None, federation_client=Mock()
+ "red",
+ http_client=None,
+ federation_client=Mock(),
+ presence_handler=presence_handler,
)
- hs.presence_handler = Mock()
- hs.presence_handler.set_state.return_value = defer.succeed(None)
-
return hs
def test_put_presence(self):
@@ -55,7 +58,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200)
- self.assertEqual(self.hs.presence_handler.set_state.call_count, 1)
+ self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 1)
def test_put_presence_disabled(self):
"""
@@ -70,4 +73,4 @@ class PresenceTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 200)
- self.assertEqual(self.hs.presence_handler.set_state.call_count, 0)
+ self.assertEqual(self.hs.get_presence_handler().set_state.call_count, 0)
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index b58768675b..737c38c396 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -41,14 +41,37 @@ class RestHelper:
auth_user_id = attr.ib()
def create_room_as(
- self, room_creator=None, is_public=True, tok=None, expect_code=200,
- ):
+ self,
+ room_creator: str = None,
+ is_public: bool = True,
+ room_version: str = None,
+ tok: str = None,
+ expect_code: int = 200,
+ ) -> str:
+ """
+ Create a room.
+
+ Args:
+ room_creator: The user ID to create the room with.
+ is_public: If True, the `visibility` parameter will be set to the
+ default (public). Otherwise, the `visibility` parameter will be set
+ to "private".
+ room_version: The room version to create the room as. Defaults to Synapse's
+ default room version.
+ tok: The access token to use in the request.
+ expect_code: The expected HTTP response code.
+
+ Returns:
+ The ID of the newly created room.
+ """
temp_id = self.auth_user_id
self.auth_user_id = room_creator
path = "/_matrix/client/r0/createRoom"
content = {}
if not is_public:
content["visibility"] = "private"
+ if room_version:
+ content["room_version"] = room_version
if tok:
path = path + "?access_token=%s" % tok
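An illustrative call using the new parameter (the token and user ID are placeholders):

```python
# inside a HomeserverTestCase, with `tok` a valid access token:
#
#     room_id = self.helper.create_room_as(
#         "@creator:test", is_public=False, room_version="6", tok=tok
#     )
```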
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index f684c37db5..77246e478f 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -38,11 +38,6 @@ class DummyRecaptchaChecker(UserInteractiveAuthChecker):
return succeed(True)
-class DummyPasswordChecker(UserInteractiveAuthChecker):
- def check_auth(self, authdict, clientip):
- return succeed(authdict["identifier"]["user"])
-
-
class FallbackAuthTests(unittest.HomeserverTestCase):
servlets = [
@@ -162,9 +157,6 @@ class UIAuthTests(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- auth_handler = hs.get_auth_handler()
- auth_handler.checkers[LoginType.PASSWORD] = DummyPasswordChecker(hs)
-
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.user_tok = self.login("test", self.user_pass)
@@ -234,6 +226,31 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
)
+ def test_grandfathered_identifier(self):
+ """Check behaviour without "identifier" dict
+
+ Synapse used to require clients to submit a "user" field for m.login.password
+        UIA; check that this still works.
+ """
+
+ device_id = self.get_device_ids()[0]
+ channel = self.delete_device(device_id, 401)
+ session = channel.json_body["session"]
+
+ # Make another request providing the UI auth flow.
+ self.delete_device(
+ device_id,
+ 200,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "user": self.user,
+ "password": self.user_pass,
+ "session": session,
+ },
+ },
+ )
+
def test_can_change_body(self):
"""
The client dict can be modified during the user interactive authentication session.
diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 699a40c3df..8f0c2430e8 100644
--- a/tests/rest/client/v2_alpha/test_register.py
+++ b/tests/rest/client/v2_alpha/test_register.py
@@ -569,7 +569,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
tok = self.login("kermit", "monkey")
# We need to manually add an email address otherwise the handler will do
# nothing.
- now = self.hs.clock.time_msec()
+ now = self.hs.get_clock().time_msec()
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
@@ -587,7 +587,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# We need to manually add an email address otherwise the handler will do
# nothing.
- now = self.hs.clock.time_msec()
+ now = self.hs.get_clock().time_msec()
self.get_success(
self.store.user_add_threepid(
user_id=user_id,
@@ -646,7 +646,7 @@ class AccountValidityBackgroundJobTestCase(unittest.HomeserverTestCase):
self.hs.config.account_validity.startup_job_max_delta = self.max_delta
- now_ms = self.hs.clock.time_msec()
+ now_ms = self.hs.get_clock().time_msec()
self.get_success(self.store._set_expiration_date_when_missing())
res = self.get_success(self.store.get_expiration_ts_for_user(user_id))
diff --git a/tests/utils.py b/tests/utils.py
index acec74e9e9..c8d3ffbaba 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -271,7 +271,7 @@ def setup_test_homeserver(
# Install @cache_in_self attributes
for key, val in kwargs.items():
- setattr(hs, key, val)
+ setattr(hs, "_" + key, val)
# Mock TLS
hs.tls_server_context_factory = Mock()
|