diff --git a/synapse/__init__.py b/synapse/__init__.py
index c38a8f613d..f2d3ac68eb 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.23.1"
+__version__ = "1.24.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/_scripts/register_new_matrix_user.py b/synapse/_scripts/register_new_matrix_user.py
index da0996edbc..dfe26dea6d 100644
--- a/synapse/_scripts/register_new_matrix_user.py
+++ b/synapse/_scripts/register_new_matrix_user.py
@@ -37,7 +37,7 @@ def request_registration(
exit=sys.exit,
):
- url = "%s/_matrix/client/r0/admin/register" % (server_location,)
+ url = "%s/_synapse/admin/v1/register" % (server_location.rstrip("/"),)
# Get the nonce
r = requests.get(url, verify=False)
diff --git a/synapse/api/auth_blocking.py b/synapse/api/auth_blocking.py
index d8fafd7cb8..9c227218e0 100644
--- a/synapse/api/auth_blocking.py
+++ b/synapse/api/auth_blocking.py
@@ -14,10 +14,12 @@
# limitations under the License.
import logging
+from typing import Optional
from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved
+from synapse.types import Requester
logger = logging.getLogger(__name__)
@@ -33,24 +35,47 @@ class AuthBlocking:
self._max_mau_value = hs.config.max_mau_value
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
+ self._server_name = hs.hostname
- async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
+ async def check_auth_blocking(
+ self,
+ user_id: Optional[str] = None,
+ threepid: Optional[dict] = None,
+ user_type: Optional[str] = None,
+ requester: Optional[Requester] = None,
+ ):
"""Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag
Args:
- user_id(str|None): If present, checks for presence against existing
+ user_id: If present, checks for presence against existing
MAU cohort
- threepid(dict|None): If present, checks for presence against configured
+ threepid: If present, checks for presence against configured
reserved threepid. Used in cases where the user is trying register
with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and
threepid should never be set at the same time.
- user_type(str|None): If present, is used to decide whether to check against
+ user_type: If present, is used to decide whether to check against
certain blocking reasons like MAU.
+
+ requester: If present, and the authenticated entity is a user, checks for
+ presence against existing MAU cohort. Passing in both a `user_id` and
+ `requester` is an error.
"""
+ if requester and user_id:
+ raise Exception(
+ "Passed in both 'user_id' and 'requester' to 'check_auth_blocking'"
+ )
+
+ if requester:
+ if requester.authenticated_entity.startswith("@"):
+ user_id = requester.authenticated_entity
+ elif requester.authenticated_entity == self._server_name:
+ # We never block the server from doing actions on behalf of
+ # users.
+ return
# Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 9c8dc785c6..895b38ae76 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -32,6 +32,7 @@ from synapse.app.phone_stats_home import start_phone_stats_home
from synapse.config.server import ListenerConfig
from synapse.crypto import context_factory
from synapse.logging.context import PreserveLoggingContext
+from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.util.async_helpers import Linearizer
from synapse.util.daemonize import daemonize_process
from synapse.util.rlimit import change_resource_limit
@@ -244,6 +245,7 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
+ @wrap_as_background_process("sighup")
def handle_sighup(*args, **kwargs):
# Tell systemd our state, if we're using it. This will silently fail if
# we're not using systemd.
@@ -254,7 +256,13 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
sdnotify(b"READY=1")
- signal.signal(signal.SIGHUP, handle_sighup)
+ # We defer running the sighup handlers until next reactor tick. This
+ # is so that we're in a sane state, e.g. flushing the logs may fail
+ # if the sighup happens in the middle of writing a log entry.
+ def run_sighup(*args, **kwargs):
+ hs.get_clock().call_later(0, handle_sighup, *args, **kwargs)
+
+ signal.signal(signal.SIGHUP, run_sighup)
register_sighup(refresh_certificate, hs)
diff --git a/synapse/config/push.py b/synapse/config/push.py
index a1f3752c8a..3adbfb73e6 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -21,8 +21,11 @@ class PushConfig(Config):
section = "push"
def read_config(self, config, **kwargs):
- push_config = config.get("push", {})
+ push_config = config.get("push") or {}
self.push_include_content = push_config.get("include_content", True)
+ self.push_group_unread_count_by_room = push_config.get(
+ "group_unread_count_by_room", True
+ )
pusher_instances = config.get("pusher_instances") or []
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
@@ -49,18 +52,33 @@ class PushConfig(Config):
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """
- # Clients requesting push notifications can either have the body of
- # the message sent in the notification poke along with other details
- # like the sender, or just the event ID and room ID (`event_id_only`).
- # If clients choose the former, this option controls whether the
- # notification request includes the content of the event (other details
- # like the sender are still included). For `event_id_only` push, it
- # has no effect.
- #
- # For modern android devices the notification content will still appear
- # because it is loaded by the app. iPhone, however will send a
- # notification saying only that a message arrived and who it came from.
- #
- #push:
- # include_content: true
+ ## Push ##
+
+ push:
+ # Clients requesting push notifications can either have the body of
+ # the message sent in the notification poke along with other details
+ # like the sender, or just the event ID and room ID (`event_id_only`).
+ # If clients choose the former, this option controls whether the
+ # notification request includes the content of the event (other details
+ # like the sender are still included). For `event_id_only` push, it
+ # has no effect.
+ #
+ # For modern android devices the notification content will still appear
+ # because it is loaded by the app. iPhone, however will send a
+ # notification saying only that a message arrived and who it came from.
+ #
+ # The default value is "true" to include message details. Uncomment to only
+ # include the event ID and room ID in push notification payloads.
+ #
+ #include_content: false
+
+ # When a push notification is received, an unread count is also sent.
+ # This number can either be calculated as the number of unread messages
+ # for the user, or the number of *rooms* the user has unread messages in.
+ #
+ # The default value is "true", meaning push clients will see the number of
+ # rooms with unread messages in them. Uncomment to instead send the number
+ # of unread messages.
+ #
+ #group_unread_count_by_room: false
"""
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index b0a77a2e43..cc5f75123c 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -347,8 +347,9 @@ class RegistrationConfig(Config):
# email will be globally disabled.
#
# Additionally, if `msisdn` is not set, registration and password resets via msisdn
- # will be disabled regardless. This is due to Synapse currently not supporting any
- # method of sending SMS messages on its own.
+ # will be disabled regardless, and users will not be able to associate an msisdn
+ # identifier to their account. This is due to Synapse currently not supporting
+ # any method of sending SMS messages on its own.
#
# To enable using an identity server for operations regarding a particular third-party
# identifier type, set the value to the URL of that identity server as shown in the
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index 2ff7dfb311..c1b8e98ae0 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -90,6 +90,8 @@ class SAML2Config(Config):
"grandfathered_mxid_source_attribute", "uid"
)
+ self.saml2_idp_entityid = saml2_config.get("idp_entityid", None)
+
# user_mapping_provider may be None if the key is present but has no value
ump_dict = saml2_config.get("user_mapping_provider") or {}
@@ -256,6 +258,12 @@ class SAML2Config(Config):
# remote:
# - url: https://our_idp/metadata.xml
+ # Allowed clock difference in seconds between the homeserver and IdP.
+ #
+ # Uncomment the below to increase the accepted time difference from 0 to 3 seconds.
+ #
+ #accepted_time_diff: 3
+
# By default, the user has to go to our login page first. If you'd like
# to allow IdP-initiated login, set 'allow_unsolicited: true' in a
# 'service.sp' section:
@@ -377,6 +385,14 @@ class SAML2Config(Config):
# value: "staff"
# - attribute: department
# value: "sales"
+
+ # If the metadata XML contains multiple IdP entities then the `idp_entityid`
+ # option must be set to the entity to redirect users to.
+ #
+ # Most deployments only have a single IdP entity and so should omit this
+ # option.
+ #
+ #idp_entityid: 'https://our_idp/entityid'
""" % {
"config_dir_path": config_dir_path
}
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index bad18f7fdf..936896656a 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,13 +15,12 @@
# limitations under the License.
import inspect
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
import synapse.events
import synapse.server
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bd8e71ae56..bb81c0e81d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -169,7 +169,9 @@ class BaseHandler:
# and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having
# homeserver.
- requester = synapse.types.create_requester(target_user, is_guest=True)
+ requester = synapse.types.create_requester(
+ target_user, is_guest=True, authenticated_entity=self.server_name
+ )
handler = self.hs.get_room_member_handler()
await handler.update_membership(
requester,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 9fc8444228..5c6458eb52 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -226,7 +226,7 @@ class ApplicationServicesHandler:
new_token: Optional[int],
users: Collection[Union[str, UserID]],
):
- logger.info("Checking interested services for %s" % (stream_key))
+ logger.debug("Checking interested services for %s" % (stream_key))
with Measure(self.clock, "notify_interested_services_ephemeral"):
for service in services:
# Only handle typing if we have the latest token
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 213baea2e3..c7dc07008a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
+# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -25,6 +26,7 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
Tuple,
Union,
@@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
# better way to break the loop
account_handler = ModuleApi(hs, self)
- self.password_providers = []
- for module, config in hs.config.password_providers:
- try:
- self.password_providers.append(
- module(config=config, account_handler=account_handler)
- )
- except Exception as e:
- logger.error("Error while initializing %r: %s", module, e)
- raise
+ self.password_providers = [
+ PasswordProvider.load(module, config, account_handler)
+ for module, config in hs.config.password_providers
+ ]
- logger.info("Extra password_providers: %r", self.password_providers)
+ logger.info("Extra password_providers: %s", self.password_providers)
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
@@ -205,15 +202,23 @@ class AuthHandler(BaseHandler):
# type in the list. (NB that the spec doesn't require us to do so and
# clients which favour types that they don't understand over those that
# they do are technically broken)
+
+ # start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = []
- if self._password_enabled:
+ if hs.config.password_localdb_enabled:
login_types.append(LoginType.PASSWORD)
+
for provider in self.password_providers:
if hasattr(provider, "get_supported_login_types"):
for t in provider.get_supported_login_types().keys():
if t not in login_types:
login_types.append(t)
+
+ if not self._password_enabled:
+ login_types.remove(LoginType.PASSWORD)
+
self._supported_login_types = login_types
+
# Login types and UI Auth types have a heavy overlap, but are not
# necessarily identical. Login types have SSO (and other login types)
# added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
@@ -230,6 +235,13 @@ class AuthHandler(BaseHandler):
burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
)
+ # Ratelimitier for failed /login attempts
+ self._failed_login_attempts_ratelimiter = Ratelimiter(
+ clock=hs.get_clock(),
+ rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+ burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+ )
+
self._clock = self.hs.get_clock()
# Expire old UI auth sessions after a period of time.
@@ -642,14 +654,8 @@ class AuthHandler(BaseHandler):
res = await checker.check_auth(authdict, clientip=clientip)
return res
- # build a v1-login-style dict out of the authdict and fall back to the
- # v1 code
- user_id = authdict.get("user")
-
- if user_id is None:
- raise SynapseError(400, "", Codes.MISSING_PARAM)
-
- (canonical_id, callback) = await self.validate_login(user_id, authdict)
+ # fall back to the v1 login flow
+ canonical_id, _ = await self.validate_login(authdict)
return canonical_id
def _get_params_recaptcha(self) -> dict:
@@ -698,8 +704,12 @@ class AuthHandler(BaseHandler):
}
async def get_access_token_for_user_id(
- self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
- ):
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ valid_until_ms: Optional[int],
+ puppets_user_id: Optional[str] = None,
+ ) -> str:
"""
Creates a new access token for the user with the given user ID.
@@ -725,13 +735,25 @@ class AuthHandler(BaseHandler):
fmt_expiry = time.strftime(
" until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
)
- logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
+
+ if puppets_user_id:
+ logger.info(
+ "Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry
+ )
+ else:
+ logger.info(
+ "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
+ )
await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user(
- user_id, access_token, device_id, valid_until_ms
+ user_id=user_id,
+ token=access_token,
+ device_id=device_id,
+ valid_until_ms=valid_until_ms,
+ puppets_user_id=puppets_user_id,
)
# the device *should* have been registered before we got here; however,
@@ -808,17 +830,17 @@ class AuthHandler(BaseHandler):
return self._supported_login_types
async def validate_login(
- self, username: str, login_submission: Dict[str, Any]
+ self, login_submission: Dict[str, Any], ratelimit: bool = False,
) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
"""Authenticates the user for the /login API
- Also used by the user-interactive auth flow to validate
- m.login.password auth types.
+ Also used by the user-interactive auth flow to validate auth types which don't
+ have an explicit UIA handler, including m.password.auth.
Args:
- username: username supplied by the user
login_submission: the whole of the login submission
(including 'type' and other relevant fields)
+ ratelimit: whether to apply the failed_login_attempt ratelimiter
Returns:
A tuple of the canonical user id, and optional callback
to be called once the access token and device id are issued
@@ -827,38 +849,160 @@ class AuthHandler(BaseHandler):
SynapseError if there was a problem with the request
LoginError if there was an authentication problem.
"""
-
- if username.startswith("@"):
- qualified_user_id = username
- else:
- qualified_user_id = UserID(username, self.hs.hostname).to_string()
-
login_type = login_submission.get("type")
- known_login_type = False
+ if not isinstance(login_type, str):
+ raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
+
+ # ideally, we wouldn't be checking the identifier unless we know we have a login
+ # method which uses it (https://github.com/matrix-org/synapse/issues/8836)
+ #
+ # But the auth providers' check_auth interface requires a username, so in
+ # practice we can only support login methods which we can map to a username
+ # anyway.
# special case to check for "password" for the check_password interface
# for the auth providers
password = login_submission.get("password")
-
if login_type == LoginType.PASSWORD:
if not self._password_enabled:
raise SynapseError(400, "Password login has been disabled.")
- if not password:
- raise SynapseError(400, "Missing parameter: password")
+ if not isinstance(password, str):
+ raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM)
- for provider in self.password_providers:
- if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
- known_login_type = True
- is_valid = await provider.check_password(qualified_user_id, password)
- if is_valid:
- return qualified_user_id, None
+ # map old-school login fields into new-school "identifier" fields.
+ identifier_dict = convert_client_dict_legacy_fields_to_identifier(
+ login_submission
+ )
- if not hasattr(provider, "get_supported_login_types") or not hasattr(
- provider, "check_auth"
- ):
- # this password provider doesn't understand custom login types
- continue
+ # convert phone type identifiers to generic threepids
+ if identifier_dict["type"] == "m.id.phone":
+ identifier_dict = login_id_phone_to_thirdparty(identifier_dict)
+
+ # convert threepid identifiers to user IDs
+ if identifier_dict["type"] == "m.id.thirdparty":
+ address = identifier_dict.get("address")
+ medium = identifier_dict.get("medium")
+
+ if medium is None or address is None:
+ raise SynapseError(400, "Invalid thirdparty identifier")
+
+ # For emails, canonicalise the address.
+ # We store all email addresses canonicalised in the DB.
+ # (See add_threepid in synapse/handlers/auth.py)
+ if medium == "email":
+ try:
+ address = canonicalise_email(address)
+ except ValueError as e:
+ raise SynapseError(400, str(e))
+
+ # We also apply account rate limiting using the 3PID as a key, as
+ # otherwise using 3PID bypasses the ratelimiting based on user ID.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.ratelimit(
+ (medium, address), update=False
+ )
+
+ # Check for login providers that support 3pid login types
+ if login_type == LoginType.PASSWORD:
+ # we've already checked that there is a (valid) password field
+ assert isinstance(password, str)
+ (
+ canonical_user_id,
+ callback_3pid,
+ ) = await self.check_password_provider_3pid(medium, address, password)
+ if canonical_user_id:
+ # Authentication through password provider and 3pid succeeded
+ return canonical_user_id, callback_3pid
+
+ # No password providers were able to handle this 3pid
+ # Check local store
+ user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+ medium, address
+ )
+ if not user_id:
+ logger.warning(
+ "unknown 3pid identifier medium %s, address %r", medium, address
+ )
+ # We mark that we've failed to log in here, as
+ # `check_password_provider_3pid` might have returned `None` due
+ # to an incorrect password, rather than the account not
+ # existing.
+ #
+ # If it returned None but the 3PID was bound then we won't hit
+ # this code path, which is fine as then the per-user ratelimit
+ # will kick in below.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.can_do_action(
+ (medium, address)
+ )
+ raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+ identifier_dict = {"type": "m.id.user", "user": user_id}
+ # by this point, the identifier should be an m.id.user: if it's anything
+ # else, we haven't understood it.
+ if identifier_dict["type"] != "m.id.user":
+ raise SynapseError(400, "Unknown login identifier type")
+
+ username = identifier_dict.get("user")
+ if not username:
+ raise SynapseError(400, "User identifier is missing 'user' key")
+
+ if username.startswith("@"):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ # Check if we've hit the failed ratelimit (but don't update it)
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.ratelimit(
+ qualified_user_id.lower(), update=False
+ )
+
+ try:
+ return await self._validate_userid_login(username, login_submission)
+ except LoginError:
+ # The user has failed to log in, so we need to update the rate
+ # limiter. Using `can_do_action` avoids us raising a ratelimit
+ # exception and masking the LoginError. The actual ratelimiting
+ # should have happened above.
+ if ratelimit:
+ self._failed_login_attempts_ratelimiter.can_do_action(
+ qualified_user_id.lower()
+ )
+ raise
+
+ async def _validate_userid_login(
+ self, username: str, login_submission: Dict[str, Any],
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ """Helper for validate_login
+
+ Handles login, once we've mapped 3pids onto userids
+
+ Args:
+ username: the username, from the identifier dict
+ login_submission: the whole of the login submission
+ (including 'type' and other relevant fields)
+ Returns:
+ A tuple of the canonical user id, and optional callback
+ to be called once the access token and device id are issued
+ Raises:
+ StoreError if there was a problem accessing the database
+ SynapseError if there was a problem with the request
+ LoginError if there was an authentication problem.
+ """
+ if username.startswith("@"):
+ qualified_user_id = username
+ else:
+ qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+ login_type = login_submission.get("type")
+ # we already checked that we have a valid login type
+ assert isinstance(login_type, str)
+
+ known_login_type = False
+
+ for provider in self.password_providers:
supported_login_types = provider.get_supported_login_types()
if login_type not in supported_login_types:
# this password provider doesn't understand this login type
@@ -883,15 +1027,17 @@ class AuthHandler(BaseHandler):
result = await provider.check_auth(username, login_type, login_dict)
if result:
- if isinstance(result, str):
- result = (result, None)
return result
if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
known_login_type = True
+ # we've already checked that there is a (valid) password field
+ password = login_submission["password"]
+ assert isinstance(password, str)
+
canonical_user_id = await self._check_local_password(
- qualified_user_id, password # type: ignore
+ qualified_user_id, password
)
if canonical_user_id:
@@ -922,19 +1068,9 @@ class AuthHandler(BaseHandler):
unsuccessful, `user_id` and `callback` are both `None`.
"""
for provider in self.password_providers:
- if hasattr(provider, "check_3pid_auth"):
- # This function is able to return a deferred that either
- # resolves None, meaning authentication failure, or upon
- # success, to a str (which is the user_id) or a tuple of
- # (user_id, callback_func), where callback_func should be run
- # after we've finished everything else
- result = await provider.check_3pid_auth(medium, address, password)
- if result:
- # Check if the return value is a str or a tuple
- if isinstance(result, str):
- # If it's a str, set callback function to None
- result = (result, None)
- return result
+ result = await provider.check_3pid_auth(medium, address, password)
+ if result:
+ return result
return None, None
@@ -992,16 +1128,11 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
- if hasattr(provider, "on_logged_out"):
- # This might return an awaitable, if it does block the log out
- # until it completes.
- result = provider.on_logged_out(
- user_id=user_info.user_id,
- device_id=user_info.device_id,
- access_token=access_token,
- )
- if inspect.isawaitable(result):
- await result
+ await provider.on_logged_out(
+ user_id=user_info.user_id,
+ device_id=user_info.device_id,
+ access_token=access_token,
+ )
# delete pushers associated with this access token
if user_info.token_id is not None:
@@ -1030,11 +1161,10 @@ class AuthHandler(BaseHandler):
# see if any of our auth providers want to know about this
for provider in self.password_providers:
- if hasattr(provider, "on_logged_out"):
- for token, token_id, device_id in tokens_and_devices:
- await provider.on_logged_out(
- user_id=user_id, device_id=device_id, access_token=token
- )
+ for token, token_id, device_id in tokens_and_devices:
+ await provider.on_logged_out(
+ user_id=user_id, device_id=device_id, access_token=token
+ )
# delete pushers associated with the access tokens
await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1358,3 +1488,127 @@ class MacaroonGenerator:
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
+
+
+class PasswordProvider:
+ """Wrapper for a password auth provider module
+
+ This class abstracts out all of the backwards-compatibility hacks for
+ password providers, to provide a consistent interface.
+ """
+
+ @classmethod
+ def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
+ try:
+ pp = module(config=config, account_handler=module_api)
+ except Exception as e:
+ logger.error("Error while initializing %r: %s", module, e)
+ raise
+ return cls(pp, module_api)
+
+ def __init__(self, pp, module_api: ModuleApi):
+ self._pp = pp
+ self._module_api = module_api
+
+ self._supported_login_types = {}
+
+ # grandfather in check_password support
+ if hasattr(self._pp, "check_password"):
+ self._supported_login_types[LoginType.PASSWORD] = ("password",)
+
+ g = getattr(self._pp, "get_supported_login_types", None)
+ if g:
+ self._supported_login_types.update(g())
+
+ def __str__(self):
+ return str(self._pp)
+
+ def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
+ """Get the login types supported by this password provider
+
+ Returns a map from a login type identifier (such as m.login.password) to an
+ iterable giving the fields which must be provided by the user in the submission
+ to the /login API.
+
+ This wrapper adds m.login.password to the list if the underlying password
+ provider supports the check_password() api.
+ """
+ return self._supported_login_types
+
+ async def check_auth(
+ self, username: str, login_type: str, login_dict: JsonDict
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ """Check if the user has presented valid login credentials
+
+ This wrapper also calls check_password() if the underlying password provider
+ supports the check_password() api and the login type is m.login.password.
+
+ Args:
+ username: user id presented by the client. Either an MXID or an unqualified
+ username.
+
+ login_type: the login type being attempted - one of the types returned by
+ get_supported_login_types()
+
+ login_dict: the dictionary of login secrets passed by the client.
+
+ Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
+ user, and `callback` is an optional callback which will be called with the
+ result from the /login call (including access_token, device_id, etc.)
+ """
+ # first grandfather in a call to check_password
+ if login_type == LoginType.PASSWORD:
+ g = getattr(self._pp, "check_password", None)
+ if g:
+ qualified_user_id = self._module_api.get_qualified_user_id(username)
+ is_valid = await self._pp.check_password(
+ qualified_user_id, login_dict["password"]
+ )
+ if is_valid:
+ return qualified_user_id, None
+
+ g = getattr(self._pp, "check_auth", None)
+ if not g:
+ return None
+ result = await g(username, login_type, login_dict)
+
+ # Check if the return value is a str or a tuple
+ if isinstance(result, str):
+ # If it's a str, set callback function to None
+ return result, None
+
+ return result
+
+ async def check_3pid_auth(
+ self, medium: str, address: str, password: str
+ ) -> Optional[Tuple[str, Optional[Callable]]]:
+ g = getattr(self._pp, "check_3pid_auth", None)
+ if not g:
+ return None
+
+ # This function is able to return a deferred that either
+ # resolves None, meaning authentication failure, or upon
+ # success, to a str (which is the user_id) or a tuple of
+ # (user_id, callback_func), where callback_func should be run
+ # after we've finished everything else
+ result = await g(medium, address, password)
+
+ # Check if the return value is a str or a tuple
+ if isinstance(result, str):
+ # If it's a str, set callback function to None
+ return result, None
+
+ return result
+
+ async def on_logged_out(
+ self, user_id: str, device_id: Optional[str], access_token: str
+ ) -> None:
+ g = getattr(self._pp, "on_logged_out", None)
+ if not g:
+ return
+
+ # This might return an awaitable, if it does block the log out
+ # until it completes.
+ result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
+ if inspect.isawaitable(result):
+ await result
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 048a3b3c0b..f4ea0a9767 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
from xml.etree import ElementTree as ET
from twisted.web.client import PartialDownloadError
@@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError
from synapse.http.site import SynapseRequest
from synapse.types import UserID, map_username_to_mxid_localpart
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -31,10 +34,10 @@ class CasHandler:
Utility class for to handle the response from a CAS SSO service.
Args:
- hs (synapse.server.HomeServer)
+ hs
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self._hostname = hs.hostname
self._auth_handler = hs.get_auth_handler()
@@ -200,27 +203,57 @@ class CasHandler:
args["session"] = session
username, user_display_name = await self._validate_ticket(ticket, args)
- localpart = map_username_to_mxid_localpart(username)
- user_id = UserID(localpart, self._hostname).to_string()
- registered_user_id = await self._auth_handler.check_user_exists(user_id)
+ # Pull out the user-agent and IP from the request.
+ user_agent = request.get_user_agent("")
+ ip_address = self.hs.get_ip_from_request(request)
+
+ # Get the matrix ID from the CAS username.
+ user_id = await self._map_cas_user_to_matrix_user(
+ username, user_display_name, user_agent, ip_address
+ )
if session:
await self._auth_handler.complete_sso_ui_auth(
- registered_user_id, session, request,
+ user_id, session, request,
)
-
else:
- if not registered_user_id:
- # Pull out the user-agent and IP from the request.
- user_agent = request.get_user_agent("")
- ip_address = self.hs.get_ip_from_request(request)
-
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=user_display_name,
- user_agent_ips=(user_agent, ip_address),
- )
+ # If this not a UI auth request than there must be a redirect URL.
+ assert client_redirect_url
await self._auth_handler.complete_sso_login(
- registered_user_id, request, client_redirect_url
+ user_id, request, client_redirect_url
)
+
+ async def _map_cas_user_to_matrix_user(
+ self,
+ remote_user_id: str,
+ display_name: Optional[str],
+ user_agent: str,
+ ip_address: str,
+ ) -> str:
+ """
+ Given a CAS username, retrieve the user ID for it and possibly register the user.
+
+ Args:
+ remote_user_id: The username from the CAS response.
+ display_name: The display name from the CAS response.
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
+
+ Returns:
+ The user ID associated with this response.
+ """
+
+ localpart = map_username_to_mxid_localpart(remote_user_id)
+ user_id = UserID(localpart, self._hostname).to_string()
+ registered_user_id = await self._auth_handler.check_user_exists(user_id)
+
+ # If the user does not exist, register it.
+ if not registered_user_id:
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=localpart,
+ default_display_name=display_name,
+ user_agent_ips=[(user_agent, ip_address)],
+ )
+
+ return registered_user_id
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 4efe6c530a..e808142365 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -39,6 +39,7 @@ class DeactivateAccountHandler(BaseHandler):
self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_identity_handler()
self.user_directory_handler = hs.get_user_directory_handler()
+ self._server_name = hs.hostname
# Flag that indicates whether the process to part users from rooms is running
self._user_parter_running = False
@@ -152,7 +153,7 @@ class DeactivateAccountHandler(BaseHandler):
for room in pending_invites:
try:
await self._room_member_handler.update_membership(
- create_requester(user),
+ create_requester(user, authenticated_entity=self._server_name),
user,
room.room_id,
"leave",
@@ -208,7 +209,7 @@ class DeactivateAccountHandler(BaseHandler):
logger.info("User parter parting %r from %r", user_id, room_id)
try:
await self._room_member_handler.update_membership(
- create_requester(user),
+ create_requester(user, authenticated_entity=self._server_name),
user,
room_id,
"leave",
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index e03caea251..b9799090f7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -68,7 +68,7 @@ from synapse.replication.http.devices import ReplicationUserDevicesResyncRestSer
from synapse.replication.http.federation import (
ReplicationCleanRoomRestServlet,
ReplicationFederationSendEventsRestServlet,
- ReplicationStoreRoomOnInviteRestServlet,
+ ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -153,12 +153,14 @@ class FederationHandler(BaseHandler):
self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
hs
)
- self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client(
+ self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(
hs
)
else:
self._device_list_updater = hs.get_device_handler().device_list_updater
- self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite
+ self._maybe_store_room_on_outlier_membership = (
+ self.store.maybe_store_room_on_outlier_membership
+ )
# When joining a room we need to queue any events for that room up.
# For each room, a list of (pdu, origin) tuples.
@@ -1618,7 +1620,7 @@ class FederationHandler(BaseHandler):
# keep a record of the room version, if we don't yet know it.
# (this may get overwritten if we later get a different room version in a
# join dance).
- await self._maybe_store_room_on_invite(
+ await self._maybe_store_room_on_outlier_membership(
room_id=event.room_id, room_version=room_version
)
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index bc3e9607ca..9b3c6b4551 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -354,7 +354,8 @@ class IdentityHandler(BaseHandler):
raise SynapseError(500, "An error was encountered when sending the email")
token_expires = (
- self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
+ self.hs.get_clock().time_msec()
+ + self.hs.config.email_validation_token_lifetime
)
await self.store.start_or_continue_validation_session(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c6791fb912..96843338ae 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -472,7 +472,7 @@ class EventCreationHandler:
Returns:
Tuple of created event, Context
"""
- await self.auth.check_auth_blocking(requester.user.to_string())
+ await self.auth.check_auth_blocking(requester=requester)
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version = event_dict["content"]["room_version"]
@@ -619,7 +619,13 @@ class EventCreationHandler:
if requester.app_service is not None:
return
- user_id = requester.user.to_string()
+ user_id = requester.authenticated_entity
+ if not user_id.startswith("@"):
+ # The authenticated entity might not be a user, e.g. if it's the
+ # server puppetting the user.
+ return
+
+ user = UserID.from_string(user_id)
# exempt the system notices user
if (
@@ -639,9 +645,7 @@ class EventCreationHandler:
if u["consent_version"] == self.config.user_consent_version:
return
- consent_uri = self._consent_uri_builder.build_user_consent_uri(
- requester.user.localpart
- )
+ consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@@ -1252,7 +1256,7 @@ class EventCreationHandler:
for user_id in members:
if not self.hs.is_mine_id(user_id):
continue
- requester = create_requester(user_id)
+ requester = create_requester(user_id, authenticated_entity=self.server_name)
try:
event, context = await self.create_event(
requester,
@@ -1273,11 +1277,6 @@ class EventCreationHandler:
requester, event, context, ratelimit=False, ignore_shadow_ban=True,
)
return True
- except ConsentNotGivenError:
- logger.info(
- "Failed to send dummy event into room %s for user %s due to "
- "lack of consent. Will try another user" % (room_id, user_id)
- )
except AuthError:
logger.info(
"Failed to send dummy event into room %s for user %s due to "
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 331d4e7e96..c605f7082a 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
@@ -34,7 +35,8 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -83,17 +85,12 @@ class OidcError(Exception):
return self.error
-class MappingException(Exception):
- """Used to catch errors when mapping the UserInfo object
- """
-
-
-class OidcHandler:
+class OidcHandler(BaseHandler):
"""Handles requests related to the OpenID Connect login flow.
"""
def __init__(self, hs: "HomeServer"):
- self.hs = hs
+ super().__init__(hs)
self._callback_url = hs.config.oidc_callback_url # type: str
self._scopes = hs.config.oidc_scopes # type: List[str]
self._user_profile_method = hs.config.oidc_user_profile_method # type: str
@@ -120,36 +117,13 @@ class OidcHandler:
self._http_client = hs.get_proxied_http_client()
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
- self._datastore = hs.get_datastore()
- self._clock = hs.get_clock()
- self._hostname = hs.hostname # type: str
self._server_name = hs.config.server_name # type: str
self._macaroon_secret_key = hs.config.macaroon_secret_key
- self._error_template = hs.config.sso_error_template
# identifier for the external_ids table
self._auth_provider_id = "oidc"
- def _render_error(
- self, request, error: str, error_description: Optional[str] = None
- ) -> None:
- """Render the error template and respond to the request with it.
-
- This is used to show errors to the user. The template of this page can
- be found under `synapse/res/templates/sso_error.html`.
-
- Args:
- request: The incoming request from the browser.
- We'll respond with an HTML page describing the error.
- error: A technical identifier for this error. Those include
- well-known OAuth2/OIDC error types like invalid_request or
- access_denied.
- error_description: A human-readable description of the error.
- """
- html = self._error_template.render(
- error=error, error_description=error_description
- )
- respond_with_html(request, 400, html)
+ self._sso_handler = hs.get_sso_handler()
def _validate_metadata(self):
"""Verifies the provider metadata.
@@ -571,7 +545,7 @@ class OidcHandler:
Since we might want to display OIDC-related errors in a user-friendly
way, we don't raise SynapseError from here. Instead, we call
- ``self._render_error`` which displays an HTML page for the error.
+ ``self._sso_handler.render_error`` which displays an HTML page for the error.
Most of the OpenID Connect logic happens here:
@@ -609,7 +583,7 @@ class OidcHandler:
if error != "access_denied":
logger.error("Error from the OIDC provider: %s %s", error, description)
- self._render_error(request, error, description)
+ self._sso_handler.render_error(request, error, description)
return
# otherwise, it is presumably a successful response. see:
@@ -619,7 +593,9 @@ class OidcHandler:
session = request.getCookie(SESSION_COOKIE_NAME) # type: Optional[bytes]
if session is None:
logger.info("No session cookie found")
- self._render_error(request, "missing_session", "No session cookie found")
+ self._sso_handler.render_error(
+ request, "missing_session", "No session cookie found"
+ )
return
# Remove the cookie. There is a good chance that if the callback failed
@@ -637,7 +613,9 @@ class OidcHandler:
# Check for the state query parameter
if b"state" not in request.args:
logger.info("State parameter is missing")
- self._render_error(request, "invalid_request", "State parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "State parameter is missing"
+ )
return
state = request.args[b"state"][0].decode()
@@ -651,17 +629,19 @@ class OidcHandler:
) = self._verify_oidc_session_token(session, state)
except MacaroonDeserializationException as e:
logger.exception("Invalid session")
- self._render_error(request, "invalid_session", str(e))
+ self._sso_handler.render_error(request, "invalid_session", str(e))
return
except MacaroonInvalidSignatureException as e:
logger.exception("Could not verify session")
- self._render_error(request, "mismatching_session", str(e))
+ self._sso_handler.render_error(request, "mismatching_session", str(e))
return
# Exchange the code with the provider
if b"code" not in request.args:
logger.info("Code parameter is missing")
- self._render_error(request, "invalid_request", "Code parameter is missing")
+ self._sso_handler.render_error(
+ request, "invalid_request", "Code parameter is missing"
+ )
return
logger.debug("Exchanging code")
@@ -670,7 +650,7 @@ class OidcHandler:
token = await self._exchange_code(code)
except OidcError as e:
logger.exception("Could not exchange code")
- self._render_error(request, e.error, e.error_description)
+ self._sso_handler.render_error(request, e.error, e.error_description)
return
logger.debug("Successfully obtained OAuth2 access token")
@@ -683,7 +663,7 @@ class OidcHandler:
userinfo = await self._fetch_userinfo(token)
except Exception as e:
logger.exception("Could not fetch userinfo")
- self._render_error(request, "fetch_error", str(e))
+ self._sso_handler.render_error(request, "fetch_error", str(e))
return
else:
logger.debug("Extracting userinfo from id_token")
@@ -691,7 +671,7 @@ class OidcHandler:
userinfo = await self._parse_id_token(token, nonce=nonce)
except Exception as e:
logger.exception("Invalid id_token")
- self._render_error(request, "invalid_token", str(e))
+ self._sso_handler.render_error(request, "invalid_token", str(e))
return
# Pull out the user-agent and IP from the request.
@@ -705,7 +685,7 @@ class OidcHandler:
)
except MappingException as e:
logger.exception("Could not map user")
- self._render_error(request, "mapping_error", str(e))
+ self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Mapping providers might not have get_extra_attributes: only call this
@@ -770,7 +750,7 @@ class OidcHandler:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (ui_auth_session_id,)
)
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -845,7 +825,7 @@ class OidcHandler:
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
return now < expiry
async def _map_userinfo_to_user(
@@ -885,71 +865,77 @@ class OidcHandler:
# to be strings.
remote_user_id = str(remote_user_id)
- logger.info(
- "Looking for existing mapping for user %s:%s",
- self._auth_provider_id,
- remote_user_id,
- )
-
- registered_user_id = await self._datastore.get_user_by_external_id(
- self._auth_provider_id, remote_user_id,
+ # Older mapping providers don't accept the `failures` argument, so we
+ # try and detect support.
+ mapper_signature = inspect.signature(
+ self._user_mapping_provider.map_user_attributes
)
+ supports_failures = "failures" in mapper_signature.parameters
- if registered_user_id is not None:
- logger.info("Found existing mapping %s", registered_user_id)
- return registered_user_id
-
- try:
- attributes = await self._user_mapping_provider.map_user_attributes(
- userinfo, token
- )
- except Exception as e:
- raise MappingException(
- "Could not extract user attributes from OIDC response: " + str(e)
- )
+ async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+ """
+ Call the mapping provider to map the OIDC userinfo and token to user attributes.
- logger.debug(
- "Retrieved user attributes from user mapping provider: %r", attributes
- )
+ This is backwards compatibility for abstraction for the SSO handler.
+ """
+ if supports_failures:
+ attributes = await self._user_mapping_provider.map_user_attributes(
+ userinfo, token, failures
+ )
+ else:
+ # If the mapping provider does not support processing failures,
+ # do not continually generate the same Matrix ID since it will
+ # continue to already be in use. Note that the error raised is
+ # arbitrary and will get turned into a MappingException.
+ if failures:
+ raise MappingException(
+ "Mapping provider does not support de-duplicating Matrix IDs"
+ )
- if not attributes["localpart"]:
- raise MappingException("localpart is empty")
+ attributes = await self._user_mapping_provider.map_user_attributes( # type: ignore
+ userinfo, token
+ )
- localpart = map_username_to_mxid_localpart(attributes["localpart"])
+ return UserAttributes(**attributes)
- user_id = UserID(localpart, self._hostname).to_string()
- users = await self._datastore.get_users_by_id_case_insensitive(user_id)
- if users:
+ async def grandfather_existing_users() -> Optional[str]:
if self._allow_existing_users:
- if len(users) == 1:
- registered_user_id = next(iter(users))
- elif user_id in users:
- registered_user_id = user_id
- else:
- raise MappingException(
- "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
- user_id, list(users.keys())
+ # If allowing existing users we want to generate a single localpart
+ # and attempt to match it.
+ attributes = await oidc_response_to_user_attributes(failures=0)
+
+ user_id = UserID(attributes.localpart, self.server_name).to_string()
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
+ if users:
+ # If an existing matrix ID is returned, then use it.
+ if len(users) == 1:
+ previously_registered_user_id = next(iter(users))
+ elif user_id in users:
+ previously_registered_user_id = user_id
+ else:
+ # Do not attempt to continue generating Matrix IDs.
+ raise MappingException(
+ "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+ user_id, users
+ )
)
- )
- else:
- # This mxid is taken
- raise MappingException("mxid '{}' is already taken".format(user_id))
- else:
- # It's the first time this user is logging in and the mapped mxid was
- # not taken, register the user
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=attributes["display_name"],
- user_agent_ips=(user_agent, ip_address),
- )
- await self._datastore.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id,
+
+ return previously_registered_user_id
+
+ return None
+
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ oidc_response_to_user_attributes,
+ grandfather_existing_users,
)
- return registered_user_id
-UserAttribute = TypedDict(
- "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+UserAttributeDict = TypedDict(
+ "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
)
C = TypeVar("C")
@@ -992,13 +978,15 @@ class OidcMappingProvider(Generic[C]):
raise NotImplementedError()
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
"""Map a `UserInfo` object into user attributes.
Args:
userinfo: An object representing the user given by the OIDC provider
token: A dict with the tokens returned by the provider
+ failures: How many times a call to this function with this
+ UserInfo has resulted in a failure.
Returns:
A dict containing the ``localpart`` and (optionally) the ``display_name``
@@ -1098,10 +1086,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
return userinfo[self._config.subject_claim]
async def map_user_attributes(
- self, userinfo: UserInfo, token: Token
- ) -> UserAttribute:
+ self, userinfo: UserInfo, token: Token, failures: int
+ ) -> UserAttributeDict:
localpart = self._config.localpart_template.render(user=userinfo).strip()
+ # Ensure only valid characters are included in the MXID.
+ localpart = map_username_to_mxid_localpart(localpart)
+
+ # Append suffix integer if last call to this function failed to produce
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
+
display_name = None # type: Optional[str]
if self._config.display_name_template is not None:
display_name = self._config.display_name_template.render(
@@ -1111,7 +1106,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
if display_name == "":
display_name = None
- return UserAttribute(localpart=localpart, display_name=display_name)
+ return UserAttributeDict(localpart=localpart, display_name=display_name)
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
extras = {} # type: Dict[str, str]
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 426b58da9e..5372753707 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -299,17 +299,22 @@ class PaginationHandler:
"""
return self._purges_by_id.get(purge_id)
- async def purge_room(self, room_id: str) -> None:
- """Purge the given room from the database"""
+ async def purge_room(self, room_id: str, force: bool = False) -> None:
+ """Purge the given room from the database.
+
+ Args:
+ room_id: room to be purged
+ force: set true to skip checking for joined users.
+ """
with await self.pagination_lock.write(room_id):
# check we know about the room
await self.store.get_room_version_id(room_id)
# first check that we have no users in this room
- joined = await self.store.is_host_joined(room_id, self._server_name)
-
- if joined:
- raise SynapseError(400, "Users are still joined to this room")
+ if not force:
+ joined = await self.store.is_host_joined(room_id, self._server_name)
+ if joined:
+ raise SynapseError(400, "Users are still joined to this room")
await self.storage.purge_events.purge_room(room_id)
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 8e014c9bb5..22d1e9d35c 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,7 +25,7 @@ The methods that define policy are:
import abc
import logging
from contextlib import contextmanager
-from typing import Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
from prometheus_client import Counter
from typing_extensions import ContextManager
@@ -46,8 +46,7 @@ from synapse.util.caches.descriptors import cached
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 74a1ddd780..dee0ef45e7 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -206,7 +206,9 @@ class ProfileHandler(BaseHandler):
# the join event to update the displayname in the rooms.
# This must be done by the target user himself.
if by_admin:
- requester = create_requester(target_user)
+ requester = create_requester(
+ target_user, authenticated_entity=requester.authenticated_entity,
+ )
await self.store.set_profile_displayname(
target_user.localpart, displayname_to_set
@@ -286,7 +288,9 @@ class ProfileHandler(BaseHandler):
# Same like set_displayname
if by_admin:
- requester = create_requester(target_user)
+ requester = create_requester(
+ target_user, authenticated_entity=requester.authenticated_entity
+ )
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index c242c409cf..153cbae7b9 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -158,7 +158,8 @@ class ReceiptEventSource:
if from_key == to_key:
return [], to_key
- # We first need to fetch all new receipts
+ # Fetch all read receipts for all rooms, up to a limit of 100. This is ordered
+ # by most recent.
rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
from_key=from_key, to_key=to_key
)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ed1ff62599..0d85fd0868 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,10 +15,12 @@
"""Contains functions for registering clients."""
import logging
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict
from synapse.replication.http.login import RegisterDeviceReplicationServlet
@@ -32,16 +34,14 @@ from synapse.types import RoomAlias, UserID, create_requester
from ._base import BaseHandler
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class RegistrationHandler(BaseHandler):
- def __init__(self, hs):
- """
-
- Args:
- hs (synapse.server.HomeServer):
- """
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
@@ -52,6 +52,7 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid
+ self._server_name = hs.hostname
self.spam_checker = hs.get_spam_checker()
@@ -70,7 +71,10 @@ class RegistrationHandler(BaseHandler):
self.session_lifetime = hs.config.session_lifetime
async def check_username(
- self, localpart, guest_access_token=None, assigned_user_id=None
+ self,
+ localpart: str,
+ guest_access_token: Optional[str] = None,
+ assigned_user_id: Optional[str] = None,
):
if types.contains_invalid_mxid_characters(localpart):
raise SynapseError(
@@ -139,39 +143,45 @@ class RegistrationHandler(BaseHandler):
async def register_user(
self,
- localpart=None,
- password_hash=None,
- guest_access_token=None,
- make_guest=False,
- admin=False,
- threepid=None,
- user_type=None,
- default_display_name=None,
- address=None,
- bind_emails=[],
- by_admin=False,
- user_agent_ips=None,
- ):
+ localpart: Optional[str] = None,
+ password_hash: Optional[str] = None,
+ guest_access_token: Optional[str] = None,
+ make_guest: bool = False,
+ admin: bool = False,
+ threepid: Optional[dict] = None,
+ user_type: Optional[str] = None,
+ default_display_name: Optional[str] = None,
+ address: Optional[str] = None,
+ bind_emails: List[str] = [],
+ by_admin: bool = False,
+ user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+ ) -> str:
"""Registers a new client on the server.
Args:
localpart: The local part of the user ID to register. If None,
one will be generated.
- password_hash (str|None): The hashed password to assign to this user so they can
+ password_hash: The hashed password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
- user_type (str|None): type of user. One of the values from
+ guest_access_token: The access token used when this was a guest
+ account.
+ make_guest: True if the the new user should be guest,
+ false to add a regular user account.
+ admin: True if the user should be registered as a server admin.
+ threepid: The threepid used for registering, if any.
+ user_type: type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
- default_display_name (unicode|None): if set, the new user's displayname
+ default_display_name: if set, the new user's displayname
will be set to this. Defaults to 'localpart'.
- address (str|None): the IP address used to perform the registration.
- bind_emails (List[str]): list of emails to bind to this account.
- by_admin (bool): True if this registration is being made via the
+ address: the IP address used to perform the registration.
+ bind_emails: list of emails to bind to this account.
+ by_admin: True if this registration is being made via the
admin api, otherwise False.
- user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
+ user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process.
Returns:
- str: user_id
+ The registere user_id.
Raises:
SynapseError if there was a problem registering.
"""
@@ -235,8 +245,10 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
fail_count = 0
- user = None
- while not user:
+ # If a default display name is not given, generate one.
+ generate_display_name = default_display_name is None
+ # This breaks on successful registration *or* errors after 10 failures.
+ while True:
# Fail after being unable to find a suitable ID a few times
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
@@ -245,7 +257,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
- if default_display_name is None:
+ if generate_display_name:
default_display_name = localpart
try:
await self.register_with_store(
@@ -261,8 +273,6 @@ class RegistrationHandler(BaseHandler):
break
except SynapseError:
# if user id is taken, just generate another
- user = None
- user_id = None
fail_count += 1
if not self.hs.config.user_consent_at_registration:
@@ -294,7 +304,7 @@ class RegistrationHandler(BaseHandler):
return user_id
- async def _create_and_join_rooms(self, user_id: str):
+ async def _create_and_join_rooms(self, user_id: str) -> None:
"""
Create the auto-join rooms and join or invite the user to them.
@@ -317,7 +327,8 @@ class RegistrationHandler(BaseHandler):
requires_join = False
if self.hs.config.registration.auto_join_user_id:
fake_requester = create_requester(
- self.hs.config.registration.auto_join_user_id
+ self.hs.config.registration.auto_join_user_id,
+ authenticated_entity=self._server_name,
)
# If the room requires an invite, add the user to the list of invites.
@@ -329,7 +340,9 @@ class RegistrationHandler(BaseHandler):
# being necessary this will occur after the invite was sent.
requires_join = True
else:
- fake_requester = create_requester(user_id)
+ fake_requester = create_requester(
+ user_id, authenticated_entity=self._server_name
+ )
# Choose whether to federate the new room.
if not self.hs.config.registration.autocreate_auto_join_rooms_federated:
@@ -362,7 +375,9 @@ class RegistrationHandler(BaseHandler):
# created it, then ensure the first user joins it.
if requires_join:
await room_member_handler.update_membership(
- requester=create_requester(user_id),
+ requester=create_requester(
+ user_id, authenticated_entity=self._server_name
+ ),
target=UserID.from_string(user_id),
room_id=info["room_id"],
# Since it was just created, there are no remote hosts.
@@ -370,15 +385,10 @@ class RegistrationHandler(BaseHandler):
action="join",
ratelimit=False,
)
-
- except ConsentNotGivenError as e:
- # Technically not necessary to pull out this error though
- # moving away from bare excepts is a good thing to do.
- logger.error("Failed to join new user to %r: %r", r, e)
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
- async def _join_rooms(self, user_id: str):
+ async def _join_rooms(self, user_id: str) -> None:
"""
Join or invite the user to the auto-join rooms.
@@ -424,9 +434,13 @@ class RegistrationHandler(BaseHandler):
# Send the invite, if necessary.
if requires_invite:
+ # If an invite is required, there must be a auto-join user ID.
+ assert self.hs.config.registration.auto_join_user_id
+
await room_member_handler.update_membership(
requester=create_requester(
- self.hs.config.registration.auto_join_user_id
+ self.hs.config.registration.auto_join_user_id,
+ authenticated_entity=self._server_name,
),
target=UserID.from_string(user_id),
room_id=room_id,
@@ -437,7 +451,9 @@ class RegistrationHandler(BaseHandler):
# Send the join.
await room_member_handler.update_membership(
- requester=create_requester(user_id),
+ requester=create_requester(
+ user_id, authenticated_entity=self._server_name
+ ),
target=UserID.from_string(user_id),
room_id=room_id,
remote_room_hosts=remote_room_hosts,
@@ -452,7 +468,7 @@ class RegistrationHandler(BaseHandler):
except Exception as e:
logger.error("Failed to join new user to %r: %r", r, e)
- async def _auto_join_rooms(self, user_id: str):
+ async def _auto_join_rooms(self, user_id: str) -> None:
"""Automatically joins users to auto join rooms - creating the room in the first place
if the user is the first to be created.
@@ -475,16 +491,16 @@ class RegistrationHandler(BaseHandler):
else:
await self._join_rooms(user_id)
- async def post_consent_actions(self, user_id):
+ async def post_consent_actions(self, user_id: str) -> None:
"""A series of registration actions that can only be carried out once consent
has been granted
Args:
- user_id (str): The user to join
+ user_id: The user to join
"""
await self._auto_join_rooms(user_id)
- async def appservice_register(self, user_localpart, as_token):
+ async def appservice_register(self, user_localpart: str, as_token: str) -> str:
user = UserID(user_localpart, self.hs.hostname)
user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token)
@@ -509,7 +525,9 @@ class RegistrationHandler(BaseHandler):
)
return user_id
- def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
+ def check_user_id_not_appservice_exclusive(
+ self, user_id: str, allowed_appservice: Optional[ApplicationService] = None
+ ) -> None:
# don't allow people to register the server notices mxid
if self._server_notices_mxid is not None:
if user_id == self._server_notices_mxid:
@@ -533,12 +551,12 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE,
)
- def check_registration_ratelimit(self, address):
+ def check_registration_ratelimit(self, address: Optional[str]) -> None:
"""A simple helper method to check whether the registration rate limit has been hit
for a given IP address
Args:
- address (str|None): the IP address used to perform the registration. If this is
+ address: the IP address used to perform the registration. If this is
None, no ratelimiting will be performed.
Raises:
@@ -549,42 +567,39 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter.ratelimit(address)
- def register_with_store(
+ async def register_with_store(
self,
- user_id,
- password_hash=None,
- was_guest=False,
- make_guest=False,
- appservice_id=None,
- create_profile_with_displayname=None,
- admin=False,
- user_type=None,
- address=None,
- shadow_banned=False,
- ):
+ user_id: str,
+ password_hash: Optional[str] = None,
+ was_guest: bool = False,
+ make_guest: bool = False,
+ appservice_id: Optional[str] = None,
+ create_profile_with_displayname: Optional[str] = None,
+ admin: bool = False,
+ user_type: Optional[str] = None,
+ address: Optional[str] = None,
+ shadow_banned: bool = False,
+ ) -> None:
"""Register user in the datastore.
Args:
- user_id (str): The desired user ID to register.
- password_hash (str|None): Optional. The password hash for this user.
- was_guest (bool): Optional. Whether this is a guest account being
+ user_id: The desired user ID to register.
+ password_hash: Optional. The password hash for this user.
+ was_guest: Optional. Whether this is a guest account being
upgraded to a non-guest account.
- make_guest (boolean): True if the the new user should be guest,
+ make_guest: True if the the new user should be guest,
false to add a regular user account.
- appservice_id (str|None): The ID of the appservice registering the user.
- create_profile_with_displayname (unicode|None): Optionally create a
+ appservice_id: The ID of the appservice registering the user.
+ create_profile_with_displayname: Optionally create a
profile for the user, setting their displayname to the given value
- admin (boolean): is an admin user?
- user_type (str|None): type of user. One of the values from
+ admin: is an admin user?
+ user_type: type of user. One of the values from
api.constants.UserTypes, or None for a normal user.
- address (str|None): the IP address used to perform the registration.
- shadow_banned (bool): Whether to shadow-ban the user
-
- Returns:
- Awaitable
+ address: the IP address used to perform the registration.
+ shadow_banned: Whether to shadow-ban the user
"""
if self.hs.config.worker_app:
- return self._register_client(
+ await self._register_client(
user_id=user_id,
password_hash=password_hash,
was_guest=was_guest,
@@ -597,7 +612,7 @@ class RegistrationHandler(BaseHandler):
shadow_banned=shadow_banned,
)
else:
- return self.store.register_user(
+ await self.store.register_user(
user_id=user_id,
password_hash=password_hash,
was_guest=was_guest,
@@ -610,22 +625,24 @@ class RegistrationHandler(BaseHandler):
)
async def register_device(
- self, user_id, device_id, initial_display_name, is_guest=False
- ):
+ self,
+ user_id: str,
+ device_id: Optional[str],
+ initial_display_name: Optional[str],
+ is_guest: bool = False,
+ ) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
The access token will be limited by the homeserver's session_lifetime config.
Args:
- user_id (str): full canonical @user:id
- device_id (str|None): The device ID to check, or None to generate
- a new one.
- initial_display_name (str|None): An optional display name for the
- device.
- is_guest (bool): Whether this is a guest account
+ user_id: full canonical @user:id
+ device_id: The device ID to check, or None to generate a new one.
+ initial_display_name: An optional display name for the device.
+ is_guest: Whether this is a guest account
Returns:
- tuple[str, str]: Tuple of device ID and access token
+ Tuple of device ID and access token
"""
if self.hs.config.worker_app:
@@ -645,7 +662,7 @@ class RegistrationHandler(BaseHandler):
)
valid_until_ms = self.clock.time_msec() + self.session_lifetime
- device_id = await self.device_handler.check_device_registered(
+ registered_device_id = await self.device_handler.check_device_registered(
user_id, device_id, initial_display_name
)
if is_guest:
@@ -655,20 +672,21 @@ class RegistrationHandler(BaseHandler):
)
else:
access_token = await self._auth_handler.get_access_token_for_user_id(
- user_id, device_id=device_id, valid_until_ms=valid_until_ms
+ user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
)
- return (device_id, access_token)
+ return (registered_device_id, access_token)
- async def post_registration_actions(self, user_id, auth_result, access_token):
+ async def post_registration_actions(
+ self, user_id: str, auth_result: dict, access_token: Optional[str]
+ ) -> None:
"""A user has completed registration
Args:
- user_id (str): The user ID that consented
- auth_result (dict): The authenticated credentials of the newly
- registered user.
- access_token (str|None): The access token of the newly logged in
- device, or None if `inhibit_login` enabled.
+ user_id: The user ID that consented
+ auth_result: The authenticated credentials of the newly registered user.
+ access_token: The access token of the newly logged in device, or
+ None if `inhibit_login` enabled.
"""
if self.hs.config.worker_app:
await self._post_registration_client(
@@ -694,19 +712,20 @@ class RegistrationHandler(BaseHandler):
if auth_result and LoginType.TERMS in auth_result:
await self._on_user_consented(user_id, self.hs.config.user_consent_version)
- async def _on_user_consented(self, user_id, consent_version):
+ async def _on_user_consented(self, user_id: str, consent_version: str) -> None:
"""A user consented to the terms on registration
Args:
- user_id (str): The user ID that consented.
- consent_version (str): version of the policy the user has
- consented to.
+ user_id: The user ID that consented.
+ consent_version: version of the policy the user has consented to.
"""
logger.info("%s has consented to the privacy policy", user_id)
await self.store.user_set_consent_version(user_id, consent_version)
await self.post_consent_actions(user_id)
- async def _register_email_threepid(self, user_id, threepid, token):
+ async def _register_email_threepid(
+ self, user_id: str, threepid: dict, token: Optional[str]
+ ) -> None:
"""Add an email address as a 3pid identifier
Also adds an email pusher for the email address, if configured in the
@@ -715,10 +734,9 @@ class RegistrationHandler(BaseHandler):
Must be called on master.
Args:
- user_id (str): id of user
- threepid (object): m.login.email.identity auth response
- token (str|None): access_token for the user, or None if not logged
- in.
+ user_id: id of user
+ threepid: m.login.email.identity auth response
+ token: access_token for the user, or None if not logged in.
"""
reqd = ("medium", "address", "validated_at")
if any(x not in threepid for x in reqd):
@@ -744,6 +762,8 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an
# invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token)
+ # The token better still exist.
+ assert user_tuple
token_id = user_tuple.token_id
await self.pusher_pool.add_pusher(
@@ -758,14 +778,14 @@ class RegistrationHandler(BaseHandler):
data={},
)
- async def _register_msisdn_threepid(self, user_id, threepid):
+ async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None:
"""Add a phone number as a 3pid identifier
Must be called on master.
Args:
- user_id (str): id of user
- threepid (object): m.login.msisdn auth response
+ user_id: id of user
+ threepid: m.login.msisdn auth response
"""
try:
assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e73031475f..930047e730 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -587,7 +587,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- await self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(requester=requester)
if (
self._server_notices_mxid is not None
@@ -1257,7 +1257,9 @@ class RoomShutdownHandler:
400, "User must be our own: %s" % (new_room_user_id,)
)
- room_creator_requester = create_requester(new_room_user_id)
+ room_creator_requester = create_requester(
+ new_room_user_id, authenticated_entity=requester_user_id
+ )
info, stream_id = await self._room_creation_handler.create_room(
room_creator_requester,
@@ -1297,7 +1299,9 @@ class RoomShutdownHandler:
try:
# Kick users from room
- target_requester = create_requester(user_id)
+ target_requester = create_requester(
+ user_id, authenticated_entity=requester_user_id
+ )
_, stream_id = await self.room_member_handler.update_membership(
requester=target_requester,
target=target_requester.user,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7cd858b7db..c002886324 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -31,7 +31,6 @@ from synapse.api.errors import (
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
-from synapse.storage.roommember import RoomsForUser
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_left_room
@@ -347,7 +346,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# later on.
content = dict(content)
- if not self.allow_per_room_profiles or requester.shadow_banned:
+ # allow the server notices mxid to set room-level profile
+ is_requester_server_notices_user = (
+ self._server_notices_mxid is not None
+ and requester.user.to_string() == self._server_notices_mxid
+ )
+
+ if (
+ not self.allow_per_room_profiles and not is_requester_server_notices_user
+ ) or requester.shadow_banned:
# Strip profile data, knowing that new profile data will be added to the
# event's content in event_creation_handler.create_event() using the target's
# global profile.
@@ -515,10 +522,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
- invite = await self.store.get_invite_for_local_user_in_room(
- user_id=target.to_string(), room_id=room_id
- ) # type: Optional[RoomsForUser]
- if not invite:
+ (
+ current_membership_type,
+ current_membership_event_id,
+ ) = await self.store.get_local_current_membership_for_user_in_room(
+ target.to_string(), room_id
+ )
+ if (
+ current_membership_type != Membership.INVITE
+ or not current_membership_event_id
+ ):
logger.info(
"%s sent a leave request to %s, but that is not an active room "
"on this server, and there is no pending invite",
@@ -528,6 +541,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise SynapseError(404, "Not a known room")
+ invite = await self.store.get_event(current_membership_event_id)
logger.info(
"%s rejects invite to %s from %s", target, room_id, invite.sender
)
@@ -965,6 +979,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor = hs.get_distributor()
self.distributor.declare("user_left_room")
+ self._server_name = hs.hostname
async def _is_remote_room_too_complex(
self, room_id: str, remote_room_hosts: List[str]
@@ -1059,7 +1074,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return event_id, stream_id
# The room is too large. Leave.
- requester = types.create_requester(user, None, False, False, None)
+ requester = types.create_requester(
+ user, authenticated_entity=self._server_name
+ )
await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave"
)
@@ -1104,32 +1121,34 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
logger.warning("Failed to reject invite: %s", e)
- return await self._locally_reject_invite(
+ return await self._generate_local_out_of_band_leave(
invite_event, txn_id, requester, content
)
- async def _locally_reject_invite(
+ async def _generate_local_out_of_band_leave(
self,
- invite_event: EventBase,
+ previous_membership_event: EventBase,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
) -> Tuple[str, int]:
- """Generate a local invite rejection
+ """Generate a local leave event for a room
- This is called after we fail to reject an invite via a remote server. It
- generates an out-of-band membership event locally.
+ This can be called after we e.g fail to reject an invite via a remote server.
+ It generates an out-of-band membership event locally.
Args:
- invite_event: the invite to be rejected
+ previous_membership_event: the previous membership event for this user
txn_id: optional transaction ID supplied by the client
- requester: user making the rejection request, according to the access token
- content: additional content to include in the rejection event.
+ requester: user making the request, according to the access token
+ content: additional content to include in the leave event.
Normally an empty dict.
- """
- room_id = invite_event.room_id
- target_user = invite_event.state_key
+ Returns:
+ A tuple containing (event_id, stream_id of the leave event)
+ """
+ room_id = previous_membership_event.room_id
+ target_user = previous_membership_event.state_key
content["membership"] = Membership.LEAVE
@@ -1141,12 +1160,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user,
}
- # the auth events for the new event are the same as that of the invite, plus
- # the invite itself.
+ # the auth events for the new event are the same as that of the previous event, plus
+ # the event itself.
#
- # the prev_events are just the invite.
- prev_event_ids = [invite_event.event_id]
- auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
+ # the prev_events consist solely of the previous membership event.
+ prev_event_ids = [previous_membership_event.event_id]
+ auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
event, context = await self.event_creation_handler.create_event(
requester,
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index fd6c5e9ea8..76d4169fe2 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -24,7 +24,8 @@ from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
@@ -37,15 +38,11 @@ from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
- import synapse.server
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
-class MappingException(Exception):
- """Used to catch errors when mapping the SAML2 response to a user."""
-
-
@attr.s(slots=True)
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
@@ -57,17 +54,14 @@ class Saml2SessionData:
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
-class SamlHandler:
- def __init__(self, hs: "synapse.server.HomeServer"):
- self.hs = hs
+class SamlHandler(BaseHandler):
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
- self._auth = hs.get_auth()
+ self._saml_idp_entityid = hs.config.saml2_idp_entityid
self._auth_handler = hs.get_auth_handler()
self._registration_handler = hs.get_registration_handler()
- self._clock = hs.get_clock()
- self._datastore = hs.get_datastore()
- self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
@@ -88,26 +82,9 @@ class SamlHandler:
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
# a lock on the mappings
- self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
-
- def _render_error(
- self, request, error: str, error_description: Optional[str] = None
- ) -> None:
- """Render the error template and respond to the request with it.
-
- This is used to show errors to the user. The template of this page can
- be found under `synapse/res/templates/sso_error.html`.
+ self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
- Args:
- request: The incoming request from the browser.
- We'll respond with an HTML page describing the error.
- error: A technical identifier for this error.
- error_description: A human-readable description of the error.
- """
- html = self._error_template.render(
- error=error, error_description=error_description
- )
- respond_with_html(request, 400, html)
+ self._sso_handler = hs.get_sso_handler()
def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
@@ -124,13 +101,13 @@ class SamlHandler:
URL to redirect to
"""
reqid, info = self._saml_client.prepare_for_authenticate(
- relay_state=client_redirect_url
+ entityid=self._saml_idp_entityid, relay_state=client_redirect_url
)
# Since SAML sessions timeout it is useful to log when they were created.
logger.info("Initiating a new SAML session: %s" % (reqid,))
- now = self._clock.time_msec()
+ now = self.clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id,
)
@@ -171,12 +148,12 @@ class SamlHandler:
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
- self._render_error(
+ self._sso_handler.render_error(
request, "unsolicited_response", "Unexpected SAML2 login."
)
return
except Exception as e:
- self._render_error(
+ self._sso_handler.render_error(
request,
"invalid_response",
"Unable to parse SAML2 response: %s." % (e,),
@@ -184,7 +161,7 @@ class SamlHandler:
return
if saml2_auth.not_signed:
- self._render_error(
+ self._sso_handler.render_error(
request, "unsigned_respond", "SAML2 response was not signed."
)
return
@@ -210,7 +187,7 @@ class SamlHandler:
# attributes.
for requirement in self._saml2_attribute_requirements:
if not _check_attribute_requirement(saml2_auth.ava, requirement):
- self._render_error(
+ self._sso_handler.render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return
@@ -226,7 +203,7 @@ class SamlHandler:
)
except MappingException as e:
logger.exception("Could not map user")
- self._render_error(request, "mapping_error", str(e))
+ self._sso_handler.render_error(request, "mapping_error", str(e))
return
# Complete the interactive auth session or the login.
@@ -272,20 +249,26 @@ class SamlHandler:
"Failed to extract remote user id from SAML response"
)
- with (await self._mapping_lock.queue(self._auth_provider_id)):
- # first of all, check if we already have a mapping for this user
- logger.info(
- "Looking for existing mapping for user %s:%s",
- self._auth_provider_id,
- remote_user_id,
+ async def saml_response_to_remapped_user_attributes(
+ failures: int,
+ ) -> UserAttributes:
+ """
+ Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
+
+ This is backwards compatibility for abstraction for the SSO handler.
+ """
+ # Call the mapping provider.
+ result = self._user_mapping_provider.saml_response_to_user_attributes(
+ saml2_auth, failures, client_redirect_url
)
- registered_user_id = await self._datastore.get_user_by_external_id(
- self._auth_provider_id, remote_user_id
+ # Remap some of the results.
+ return UserAttributes(
+ localpart=result.get("mxid_localpart"),
+ display_name=result.get("displayname"),
+ emails=result.get("emails", []),
)
- if registered_user_id is not None:
- logger.info("Found existing mapping %s", registered_user_id)
- return registered_user_id
+ async def grandfather_existing_users() -> Optional[str]:
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
if (
@@ -294,75 +277,35 @@ class SamlHandler:
):
attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
user_id = UserID(
- map_username_to_mxid_localpart(attrval), self._hostname
+ map_username_to_mxid_localpart(attrval), self.server_name
).to_string()
- logger.info(
+
+ logger.debug(
"Looking for existing account based on mapped %s %s",
self._grandfathered_mxid_source_attribute,
user_id,
)
- users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+ users = await self.store.get_users_by_id_case_insensitive(user_id)
if users:
registered_user_id = list(users.keys())[0]
logger.info("Grandfathering mapping to %s", registered_user_id)
- await self._datastore.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id
- )
return registered_user_id
- # Map saml response to user attributes using the configured mapping provider
- for i in range(1000):
- attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
- saml2_auth, i, client_redirect_url=client_redirect_url,
- )
-
- logger.debug(
- "Retrieved SAML attributes from user mapping provider: %s "
- "(attempt %d)",
- attribute_dict,
- i,
- )
-
- localpart = attribute_dict.get("mxid_localpart")
- if not localpart:
- raise MappingException(
- "Error parsing SAML2 response: SAML mapping provider plugin "
- "did not return a mxid_localpart value"
- )
-
- displayname = attribute_dict.get("displayname")
- emails = attribute_dict.get("emails", [])
-
- # Check if this mxid already exists
- if not await self._datastore.get_users_by_id_case_insensitive(
- UserID(localpart, self._hostname).to_string()
- ):
- # This mxid is free
- break
- else:
- # Unable to generate a username in 1000 iterations
- # Break and return error to the user
- raise MappingException(
- "Unable to generate a Matrix ID from the SAML response"
- )
+ return None
- logger.info("Mapped SAML user to local part %s", localpart)
-
- registered_user_id = await self._registration_handler.register_user(
- localpart=localpart,
- default_display_name=displayname,
- bind_emails=emails,
- user_agent_ips=(user_agent, ip_address),
- )
-
- await self._datastore.record_user_external_id(
- self._auth_provider_id, remote_user_id, registered_user_id
+ with (await self._mapping_lock.queue(self._auth_provider_id)):
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ saml_response_to_remapped_user_attributes,
+ grandfather_existing_users,
)
- return registered_user_id
def expire_sessions(self):
- expire_before = self._clock.time_msec() - self._saml2_session_lifetime
+ expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
for reqid, data in self._outstanding_requests_dict.items():
if data.creation_time < expire_before:
@@ -474,11 +417,11 @@ class DefaultSamlMappingProvider:
)
# Use the configured mapper for this mxid_source
- base_mxid_localpart = self._mxid_mapper(mxid_source)
+ localpart = self._mxid_mapper(mxid_source)
# Append suffix integer if last call to this function failed to produce
- # a usable mxid
- localpart = base_mxid_localpart + (str(failures) if failures else "")
+ # a usable mxid.
+ localpart += str(failures) if failures else ""
# Retrieve the display name from the saml response
# If displayname is None, the mxid_localpart will be used instead
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
new file mode 100644
index 0000000000..47ad96f97e
--- /dev/null
+++ b/synapse/handlers/sso.py
@@ -0,0 +1,244 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+
+import attr
+
+from synapse.api.errors import RedirectException
+from synapse.handlers._base import BaseHandler
+from synapse.http.server import respond_with_html
+from synapse.types import UserID, contains_invalid_mxid_characters
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class MappingException(Exception):
+ """Used to catch errors when mapping an SSO response to user attributes.
+
+ Note that the msg that is raised is shown to end-users.
+ """
+
+
+@attr.s
+class UserAttributes:
+ localpart = attr.ib(type=str)
+ display_name = attr.ib(type=Optional[str], default=None)
+ emails = attr.ib(type=List[str], default=attr.Factory(list))
+
+
+class SsoHandler(BaseHandler):
+ # The number of attempts to ask the mapping provider for when generating an MXID.
+ _MAP_USERNAME_RETRIES = 1000
+
+ def __init__(self, hs: "HomeServer"):
+ super().__init__(hs)
+ self._registration_handler = hs.get_registration_handler()
+ self._error_template = hs.config.sso_error_template
+
+ def render_error(
+ self, request, error: str, error_description: Optional[str] = None
+ ) -> None:
+ """Renders the error template and responds with it.
+
+ This is used to show errors to the user. The template of this page can
+ be found under `synapse/res/templates/sso_error.html`.
+
+ Args:
+ request: The incoming request from the browser.
+ We'll respond with an HTML page describing the error.
+ error: A technical identifier for this error.
+ error_description: A human-readable description of the error.
+ """
+ html = self._error_template.render(
+ error=error, error_description=error_description
+ )
+ respond_with_html(request, 400, html)
+
+ async def get_sso_user_by_remote_user_id(
+ self, auth_provider_id: str, remote_user_id: str
+ ) -> Optional[str]:
+ """
+ Maps the user ID of a remote IdP to a mxid for a previously seen user.
+
+ If the user has not been seen yet, this will return None.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The user ID according to the remote IdP. This might
+ be an e-mail address, a GUID, or some other form. It must be
+ unique and immutable.
+
+ Returns:
+ The mxid of a previously seen user.
+ """
+ logger.debug(
+ "Looking for existing mapping for user %s:%s",
+ auth_provider_id,
+ remote_user_id,
+ )
+
+ # Check if we already have a mapping for this user.
+ previously_registered_user_id = await self.store.get_user_by_external_id(
+ auth_provider_id, remote_user_id,
+ )
+
+ # A match was found, return the user ID.
+ if previously_registered_user_id is not None:
+ logger.info(
+ "Found existing mapping for IdP '%s' and remote_user_id '%s': %s",
+ auth_provider_id,
+ remote_user_id,
+ previously_registered_user_id,
+ )
+ return previously_registered_user_id
+
+ # No match.
+ return None
+
+ async def get_mxid_from_sso(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ user_agent: str,
+ ip_address: str,
+ sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+ grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
+ ) -> str:
+ """
+ Given an SSO ID, retrieve the user ID for it and possibly register the user.
+
+ This first checks if the SSO ID has previously been linked to a matrix ID,
+ if it has that matrix ID is returned regardless of the current mapping
+ logic.
+
+ If a callable is provided for grandfathering users, it is called and can
+ potentially return a matrix ID to use. If it does, the SSO ID is linked to
+ this matrix ID for subsequent calls.
+
+ The mapping function is called (potentially multiple times) to generate
+ a localpart for the user.
+
+ If an unused localpart is generated, the user is registered from the
+ given user-agent and IP address and the SSO ID is linked to this matrix
+ ID for subsequent calls.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The unique identifier from the SSO provider.
+ user_agent: The user agent of the client making the request.
+ ip_address: The IP address of the client making the request.
+ sso_to_matrix_id_mapper: A callable to generate the user attributes.
+ The only parameter is an integer which represents the amount of
+ times the returned mxid localpart mapping has failed.
+
+ It is expected that the mapper can raise two exceptions, which
+ will get passed through to the caller:
+
+ MappingException if there was a problem mapping the response
+ to the user.
+ RedirectException to redirect to an additional page (e.g.
+ to prompt the user for more information).
+ grandfather_existing_users: A callable which can return an previously
+ existing matrix ID. The SSO ID is then linked to the returned
+ matrix ID.
+
+ Returns:
+ The user ID associated with the SSO response.
+
+ Raises:
+ MappingException if there was a problem mapping the response to a user.
+ RedirectException: if the mapping provider needs to redirect the user
+ to an additional page. (e.g. to prompt for more information)
+
+ """
+ # first of all, check if we already have a mapping for this user
+ previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
+ if previously_registered_user_id:
+ return previously_registered_user_id
+
+ # Check for grandfathering of users.
+ if grandfather_existing_users:
+ previously_registered_user_id = await grandfather_existing_users()
+ if previously_registered_user_id:
+ # Future logins should also match this user ID.
+ await self.store.record_user_external_id(
+ auth_provider_id, remote_user_id, previously_registered_user_id
+ )
+ return previously_registered_user_id
+
+ # Otherwise, generate a new user.
+ for i in range(self._MAP_USERNAME_RETRIES):
+ try:
+ attributes = await sso_to_matrix_id_mapper(i)
+ except (RedirectException, MappingException):
+ # Mapping providers are allowed to issue a redirect (e.g. to ask
+ # the user for more information) and can issue a mapping exception
+ # if a name cannot be generated.
+ raise
+ except Exception as e:
+ # Any other exception is unexpected.
+ raise MappingException(
+ "Could not extract user attributes from SSO response."
+ ) from e
+
+ logger.debug(
+ "Retrieved user attributes from user mapping provider: %r (attempt %d)",
+ attributes,
+ i,
+ )
+
+ if not attributes.localpart:
+ raise MappingException(
+ "Error parsing SSO response: SSO mapping provider plugin "
+ "did not return a localpart value"
+ )
+
+ # Check if this mxid already exists
+ user_id = UserID(attributes.localpart, self.server_name).to_string()
+ if not await self.store.get_users_by_id_case_insensitive(user_id):
+ # This mxid is free
+ break
+ else:
+ # Unable to generate a username in 1000 iterations
+ # Break and return error to the user
+ raise MappingException(
+ "Unable to generate a Matrix ID from the SSO response"
+ )
+
+ # Since the localpart is provided via a potentially untrusted module,
+ # ensure the MXID is valid before registering.
+ if contains_invalid_mxid_characters(attributes.localpart):
+ raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
+
+ logger.debug("Mapped SSO user to local part %s", attributes.localpart)
+ registered_user_id = await self._registration_handler.register_user(
+ localpart=attributes.localpart,
+ default_display_name=attributes.display_name,
+ bind_emails=attributes.emails,
+ user_agent_ips=[(user_agent, ip_address)],
+ )
+
+ await self.store.record_user_external_id(
+ auth_provider_id, remote_user_id, registered_user_id
+ )
+ return registered_user_id
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 32e53c2d25..9827c7eb8d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -31,6 +31,7 @@ from synapse.types import (
Collection,
JsonDict,
MutableStateMap,
+ Requester,
RoomStreamToken,
StateMap,
StreamToken,
@@ -260,6 +261,7 @@ class SyncHandler:
async def wait_for_sync_for_user(
self,
+ requester: Requester,
sync_config: SyncConfig,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
@@ -273,7 +275,7 @@ class SyncHandler:
# not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur)
user_id = sync_config.user.to_string()
- await self.auth.check_auth_blocking(user_id)
+ await self.auth.check_auth_blocking(requester=requester)
res = await self.response_cache.wrap(
sync_config.request_key,
diff --git a/synapse/http/client.py b/synapse/http/client.py
index f409368802..e5b13593f2 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -14,9 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-import urllib
+import urllib.parse
from io import BytesIO
from typing import (
+ TYPE_CHECKING,
Any,
BinaryIO,
Dict,
@@ -31,7 +32,7 @@ from typing import (
import treq
from canonicaljson import encode_canonical_json
-from netaddr import IPAddress
+from netaddr import IPAddress, IPSet
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -39,6 +40,8 @@ from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.interfaces import (
+ IAddress,
+ IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
)
@@ -53,7 +56,7 @@ from twisted.web.client import (
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IAgent, IBodyProducer, IResponse
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -63,6 +66,9 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
@@ -84,12 +90,19 @@ QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]
-def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
+def check_against_blacklist(
+ ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
+) -> bool:
"""
+ Compares an IP address to allowed and disallowed IP sets.
+
Args:
- ip_address (netaddr.IPAddress)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ ip_address: The IP address to check
+ ip_whitelist: Allowed IP addresses.
+ ip_blacklist: Disallowed IP addresses.
+
+ Returns:
+ True if the IP address is in the blacklist and not in the whitelist.
"""
if ip_address in ip_blacklist:
if ip_whitelist is None or ip_address not in ip_whitelist:
@@ -118,23 +131,30 @@ class IPBlacklistingResolver:
addresses, preventing DNS rebinding attacks on URL preview.
"""
- def __init__(self, reactor, ip_whitelist, ip_blacklist):
+ def __init__(
+ self,
+ reactor: IReactorPluggableNameResolver,
+ ip_whitelist: Optional[IPSet],
+ ip_blacklist: IPSet,
+ ):
"""
Args:
- reactor (twisted.internet.reactor)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ reactor: The twisted reactor.
+ ip_whitelist: IP addresses to allow.
+ ip_blacklist: IP addresses to disallow.
"""
self._reactor = reactor
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
- def resolveHostName(self, recv, hostname, portNumber=0):
+ def resolveHostName(
+ self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
+ ) -> IResolutionReceiver:
r = recv()
- addresses = []
+ addresses = [] # type: List[IAddress]
- def _callback():
+ def _callback() -> None:
r.resolutionBegan(None)
has_bad_ip = False
@@ -161,15 +181,15 @@ class IPBlacklistingResolver:
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
- def resolutionBegan(resolutionInProgress):
+ def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass
@staticmethod
- def addressResolved(address):
+ def addressResolved(address: IAddress) -> None:
addresses.append(address)
@staticmethod
- def resolutionComplete():
+ def resolutionComplete() -> None:
_callback()
self._reactor.nameResolver.resolveHostName(
@@ -185,19 +205,29 @@ class BlacklistingAgentWrapper(Agent):
directly (without an IP address lookup).
"""
- def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None):
+ def __init__(
+ self,
+ agent: IAgent,
+ ip_whitelist: Optional[IPSet] = None,
+ ip_blacklist: Optional[IPSet] = None,
+ ):
"""
Args:
- agent (twisted.web.client.Agent): The Agent to wrap.
- reactor (twisted.internet.reactor)
- ip_whitelist (netaddr.IPSet)
- ip_blacklist (netaddr.IPSet)
+ agent: The Agent to wrap.
+ ip_whitelist: IP addresses to allow.
+ ip_blacklist: IP addresses to disallow.
"""
self._agent = agent
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist
- def request(self, method, uri, headers=None, bodyProducer=None):
+ def request(
+ self,
+ method: bytes,
+ uri: bytes,
+ headers: Optional[Headers] = None,
+ bodyProducer: Optional[IBodyProducer] = None,
+ ) -> defer.Deferred:
h = urllib.parse.urlparse(uri.decode("ascii"))
try:
@@ -226,23 +256,23 @@ class SimpleHttpClient:
def __init__(
self,
- hs,
- treq_args={},
- ip_whitelist=None,
- ip_blacklist=None,
- http_proxy=None,
- https_proxy=None,
+ hs: "HomeServer",
+ treq_args: Dict[str, Any] = {},
+ ip_whitelist: Optional[IPSet] = None,
+ ip_blacklist: Optional[IPSet] = None,
+ http_proxy: Optional[bytes] = None,
+ https_proxy: Optional[bytes] = None,
):
"""
Args:
- hs (synapse.server.HomeServer)
- treq_args (dict): Extra keyword arguments to be given to treq.request.
- ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that
+ hs
+ treq_args: Extra keyword arguments to be given to treq.request.
+ ip_blacklist: The IP addresses that are blacklisted that
we may not request.
- ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
+ ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
- http_proxy (bytes): proxy server to use for http connections. host[:port]
- https_proxy (bytes): proxy server to use for https connections. host[:port]
+ http_proxy: proxy server to use for http connections. host[:port]
+ https_proxy: proxy server to use for https connections. host[:port]
"""
self.hs = hs
@@ -306,7 +336,6 @@ class SimpleHttpClient:
# by the DNS resolution.
self.agent = BlacklistingAgentWrapper(
self.agent,
- self.reactor,
ip_whitelist=self._ip_whitelist,
ip_blacklist=self._ip_blacklist,
)
@@ -397,7 +426,7 @@ class SimpleHttpClient:
async def post_urlencoded_get_json(
self,
uri: str,
- args: Mapping[str, Union[str, List[str]]] = {},
+ args: Optional[Mapping[str, Union[str, List[str]]]] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""
@@ -422,9 +451,7 @@ class SimpleHttpClient:
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
- query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode(
- "utf8"
- )
+ query_bytes = encode_query_args(args)
actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
@@ -432,7 +459,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=query_bytes
@@ -479,7 +506,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=json_str
@@ -495,7 +522,10 @@ class SimpleHttpClient:
)
async def get_json(
- self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
+ self,
+ uri: str,
+ args: Optional[QueryParams] = None,
+ headers: Optional[RawHeaders] = None,
) -> Any:
"""Gets some json from the given URI.
@@ -516,7 +546,7 @@ class SimpleHttpClient:
"""
actual_headers = {b"Accept": [b"application/json"]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
@@ -525,7 +555,7 @@ class SimpleHttpClient:
self,
uri: str,
json_body: Any,
- args: QueryParams = {},
+ args: Optional[QueryParams] = None,
headers: RawHeaders = None,
) -> Any:
"""Puts some json to the given URI.
@@ -546,9 +576,9 @@ class SimpleHttpClient:
ValueError: if the response was not JSON
"""
- if len(args):
- query_bytes = urllib.parse.urlencode(args, True)
- uri = "%s?%s" % (uri, query_bytes)
+ if args:
+ query_str = urllib.parse.urlencode(args, True)
+ uri = "%s?%s" % (uri, query_str)
json_str = encode_canonical_json(json_body)
@@ -558,7 +588,7 @@ class SimpleHttpClient:
b"Accept": [b"application/json"],
}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request(
"PUT", uri, headers=Headers(actual_headers), data=json_str
@@ -574,7 +604,10 @@ class SimpleHttpClient:
)
async def get_raw(
- self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
+ self,
+ uri: str,
+ args: Optional[QueryParams] = None,
+ headers: Optional[RawHeaders] = None,
) -> bytes:
"""Gets raw text from the given URI.
@@ -592,13 +625,13 @@ class SimpleHttpClient:
HttpResponseException on a non-2xx HTTP response.
"""
- if len(args):
- query_bytes = urllib.parse.urlencode(args, True)
- uri = "%s?%s" % (uri, query_bytes)
+ if args:
+ query_str = urllib.parse.urlencode(args, True)
+ uri = "%s?%s" % (uri, query_str)
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request("GET", uri, headers=Headers(actual_headers))
@@ -641,7 +674,7 @@ class SimpleHttpClient:
actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
- actual_headers.update(headers)
+ actual_headers.update(headers) # type: ignore
response = await self.request("GET", url, headers=Headers(actual_headers))
@@ -649,12 +682,13 @@ class SimpleHttpClient:
if (
b"Content-Length" in resp_headers
+ and max_size
and int(resp_headers[b"Content-Length"][0]) > max_size
):
- logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
+ logger.warning("Requested URL is too large > %r bytes" % (max_size,))
raise SynapseError(
502,
- "Requested file is too large > %r bytes" % (self.max_size,),
+ "Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)
@@ -668,7 +702,7 @@ class SimpleHttpClient:
try:
length = await make_deferred_yieldable(
- _readBodyToFile(response, output_stream, max_size)
+ readBodyToFile(response, output_stream, max_size)
)
except SynapseError:
# This can happen e.g. because the body is too large.
@@ -696,18 +730,16 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f
-# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
-# The two should be factored out.
-
-
class _ReadBodyToFileProtocol(protocol.Protocol):
- def __init__(self, stream, deferred, max_size):
+ def __init__(
+ self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
+ ):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size
- def dataReceived(self, data):
+ def dataReceived(self, data: bytes) -> None:
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
@@ -721,7 +753,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred = defer.Deferred()
self.transport.loseConnection()
- def connectionLost(self, reason):
+ def connectionLost(self, reason: Failure) -> None:
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
@@ -732,35 +764,48 @@ class _ReadBodyToFileProtocol(protocol.Protocol):
self.deferred.errback(reason)
-# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
-# The two should be factored out.
+def readBodyToFile(
+ response: IResponse, stream: BinaryIO, max_size: Optional[int]
+) -> defer.Deferred:
+ """
+ Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
+ Args:
+ response: The HTTP response to read from.
+ stream: The file-object to write to.
+ max_size: The maximum file size to allow.
+
+ Returns:
+ A Deferred which resolves to the length of the read body.
+ """
-def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
-def encode_urlencode_args(args):
- return {k: encode_urlencode_arg(v) for k, v in args.items()}
+def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
+ """
+ Encodes a map of query arguments to bytes which can be appended to a URL.
+ Args:
+ args: The query arguments, a mapping of string to string or list of strings.
+
+ Returns:
+ The query arguments encoded as bytes.
+ """
+ if args is None:
+ return b""
-def encode_urlencode_arg(arg):
- if isinstance(arg, str):
- return arg.encode("utf-8")
- elif isinstance(arg, list):
- return [encode_urlencode_arg(i) for i in arg]
- else:
- return arg
+ encoded_args = {}
+ for k, vs in args.items():
+ if isinstance(vs, str):
+ vs = [vs]
+ encoded_args[k] = [v.encode("utf8") for v in vs]
+ query_str = urllib.parse.urlencode(encoded_args, True)
-def _print_ex(e):
- if hasattr(e, "reasons") and e.reasons:
- for ex in e.reasons:
- _print_ex(ex)
- else:
- logger.exception(e)
+ return query_str.encode("utf8")
class InsecureInterceptableContextFactory(ssl.ContextFactory):
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 83d6196d4a..e77f9587d0 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -12,21 +12,25 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-import urllib
-from typing import List
+import urllib.parse
+from typing import List, Optional
from netaddr import AddrFormatError, IPAddress
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
-from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import Agent, HTTPConnectionPool
+from twisted.internet.interfaces import (
+ IProtocolFactory,
+ IReactorCore,
+ IStreamClientEndpoint,
+)
+from twisted.web.client import URI, Agent, HTTPConnectionPool
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IAgentEndpointFactory
+from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -44,30 +48,30 @@ class MatrixFederationAgent:
Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.)
Args:
- reactor (IReactor): twisted reactor to use for underlying requests
+ reactor: twisted reactor to use for underlying requests
- tls_client_options_factory (FederationPolicyForHTTPS|None):
+ tls_client_options_factory:
factory to use for fetching client tls options, or none to disable TLS.
- user_agent (bytes):
+ user_agent:
The user agent header to use for federation requests.
- _srv_resolver (SrvResolver|None):
- SRVResolver impl to use for looking up SRV records. None to use a default
- implementation.
+ _srv_resolver:
+ SrvResolver implementation to use for looking up SRV records. None
+ to use a default implementation.
- _well_known_resolver (WellKnownResolver|None):
+ _well_known_resolver:
WellKnownResolver to use to perform well-known lookups. None to use a
default implementation.
"""
def __init__(
self,
- reactor,
- tls_client_options_factory,
- user_agent,
- _srv_resolver=None,
- _well_known_resolver=None,
+ reactor: IReactorCore,
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ user_agent: bytes,
+ _srv_resolver: Optional[SrvResolver] = None,
+ _well_known_resolver: Optional[WellKnownResolver] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@@ -99,15 +103,20 @@ class MatrixFederationAgent:
self._well_known_resolver = _well_known_resolver
@defer.inlineCallbacks
- def request(self, method, uri, headers=None, bodyProducer=None):
+ def request(
+ self,
+ method: bytes,
+ uri: bytes,
+ headers: Optional[Headers] = None,
+ bodyProducer: Optional[IBodyProducer] = None,
+ ) -> defer.Deferred:
"""
Args:
- method (bytes): HTTP method: GET/POST/etc
- uri (bytes): Absolute URI to be retrieved
- headers (twisted.web.http_headers.Headers|None):
- HTTP headers to send with the request, or None to
- send no extra headers.
- bodyProducer (twisted.web.iweb.IBodyProducer|None):
+ method: HTTP method: GET/POST/etc
+ uri: Absolute URI to be retrieved
+ headers:
+ HTTP headers to send with the request, or None to send no extra headers.
+ bodyProducer:
An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or None if the request is to have
@@ -123,6 +132,9 @@ class MatrixFederationAgent:
# explicit port.
parsed_uri = urllib.parse.urlparse(uri)
+ # There must be a valid hostname.
+ assert parsed_uri.hostname
+
# If this is a matrix:// URI check if the server has delegated matrix
# traffic using well-known delegation.
#
@@ -179,7 +191,12 @@ class MatrixHostnameEndpointFactory:
"""Factory for MatrixHostnameEndpoint for parsing to an Agent.
"""
- def __init__(self, reactor, tls_client_options_factory, srv_resolver):
+ def __init__(
+ self,
+ reactor: IReactorCore,
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ srv_resolver: Optional[SrvResolver],
+ ):
self._reactor = reactor
self._tls_client_options_factory = tls_client_options_factory
@@ -203,15 +220,20 @@ class MatrixHostnameEndpoint:
resolution (i.e. via SRV). Does not check for well-known delegation.
Args:
- reactor (IReactor)
- tls_client_options_factory (ClientTLSOptionsFactory|None):
+ reactor: twisted reactor to use for underlying requests
+ tls_client_options_factory:
factory to use for fetching client tls options, or none to disable TLS.
- srv_resolver (SrvResolver): The SRV resolver to use
- parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting
- to connect to.
+ srv_resolver: The SRV resolver to use
+ parsed_uri: The parsed URI that we're wanting to connect to.
"""
- def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri):
+ def __init__(
+ self,
+ reactor: IReactorCore,
+ tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+ srv_resolver: SrvResolver,
+ parsed_uri: URI,
+ ):
self._reactor = reactor
self._parsed_uri = parsed_uri
@@ -231,13 +253,13 @@ class MatrixHostnameEndpoint:
self._srv_resolver = srv_resolver
- def connect(self, protocol_factory):
+ def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
"""Implements IStreamClientEndpoint interface
"""
return run_in_background(self._do_connect, protocol_factory)
- async def _do_connect(self, protocol_factory):
+ async def _do_connect(self, protocol_factory: IProtocolFactory) -> None:
first_exception = None
server_list = await self._resolve_server()
@@ -303,20 +325,20 @@ class MatrixHostnameEndpoint:
return [Server(host, 8448)]
-def _is_ip_literal(host):
+def _is_ip_literal(host: bytes) -> bool:
"""Test if the given host name is either an IPv4 or IPv6 literal.
Args:
- host (bytes)
+ host: The host name to check
Returns:
- bool
+ True if the hostname is an IP address literal.
"""
- host = host.decode("ascii")
+ host_str = host.decode("ascii")
try:
- IPAddress(host)
+ IPAddress(host_str)
return True
except AddrFormatError:
return False
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 1cc666fbf6..5e08ef1664 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
import random
import time
@@ -21,10 +20,11 @@ from typing import Callable, Dict, Optional, Tuple
import attr
from twisted.internet import defer
+from twisted.internet.interfaces import IReactorTime
from twisted.web.client import RedirectAgent, readBody
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IAgent, IResponse
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
@@ -81,11 +81,11 @@ class WellKnownResolver:
def __init__(
self,
- reactor,
- agent,
- user_agent,
- well_known_cache=None,
- had_well_known_cache=None,
+ reactor: IReactorTime,
+ agent: IAgent,
+ user_agent: bytes,
+ well_known_cache: Optional[TTLCache] = None,
+ had_well_known_cache: Optional[TTLCache] = None,
):
self._reactor = reactor
self._clock = Clock(reactor)
@@ -127,7 +127,7 @@ class WellKnownResolver:
with Measure(self._clock, "get_well_known"):
result, cache_period = await self._fetch_well_known(
server_name
- ) # type: Tuple[Optional[bytes], float]
+ ) # type: Optional[bytes], float
except _FetchWellKnownFailure as e:
if prev_result and e.temporary:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 7e17cdb73e..4e27f93b7a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -17,8 +17,9 @@ import cgi
import logging
import random
import sys
-import urllib
+import urllib.parse
from io import BytesIO
+from typing import Callable, Dict, List, Optional, Tuple, Union
import attr
import treq
@@ -27,25 +28,27 @@ from prometheus_client import Counter
from signedjson.sign import sign_json
from zope.interface import implementer
-from twisted.internet import defer, protocol
+from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
-from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IResponse
+from twisted.web.iweb import IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
- Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
- SynapseError,
)
from synapse.http import QuieterFileBodyProducer
-from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
+from synapse.http.client import (
+ BlacklistingAgentWrapper,
+ IPBlacklistingResolver,
+ encode_query_args,
+ readBodyToFile,
+)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import (
@@ -54,6 +57,7 @@ from synapse.logging.opentracing import (
start_active_span,
tags,
)
+from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -76,47 +80,44 @@ MAXINT = sys.maxsize
_next_id = 1
+QueryArgs = Dict[str, Union[str, List[str]]]
+
+
@attr.s(slots=True, frozen=True)
class MatrixFederationRequest:
- method = attr.ib()
+ method = attr.ib(type=str)
"""HTTP method
- :type: str
"""
- path = attr.ib()
+ path = attr.ib(type=str)
"""HTTP path
- :type: str
"""
- destination = attr.ib()
+ destination = attr.ib(type=str)
"""The remote server to send the HTTP request to.
- :type: str"""
+ """
- json = attr.ib(default=None)
+ json = attr.ib(default=None, type=Optional[JsonDict])
"""JSON to send in the body.
- :type: dict|None
"""
- json_callback = attr.ib(default=None)
+ json_callback = attr.ib(default=None, type=Optional[Callable[[], JsonDict]])
"""A callback to generate the JSON.
- :type: func|None
"""
- query = attr.ib(default=None)
+ query = attr.ib(default=None, type=Optional[dict])
"""Query arguments.
- :type: dict|None
"""
- txn_id = attr.ib(default=None)
+ txn_id = attr.ib(default=None, type=Optional[str])
"""Unique ID for this request (for logging)
- :type: str|None
"""
uri = attr.ib(init=False, type=bytes)
"""The URI of this request
"""
- def __attrs_post_init__(self):
+ def __attrs_post_init__(self) -> None:
global _next_id
txn_id = "%s-O-%s" % (self.method, _next_id)
_next_id = (_next_id + 1) % (MAXINT - 1)
@@ -136,7 +137,7 @@ class MatrixFederationRequest:
)
object.__setattr__(self, "uri", uri)
- def get_json(self):
+ def get_json(self) -> Optional[JsonDict]:
if self.json_callback:
return self.json_callback()
return self.json
@@ -148,7 +149,7 @@ async def _handle_json_response(
request: MatrixFederationRequest,
response: IResponse,
start_ms: int,
-):
+) -> JsonDict:
"""
Reads the JSON body of a response, with a timeout
@@ -160,7 +161,7 @@ async def _handle_json_response(
start_ms: Timestamp when request was made
Returns:
- dict: parsed JSON response
+ The parsed JSON response
"""
try:
check_content_type_is_json(response.headers)
@@ -250,9 +251,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
- self.agent,
- self.reactor,
- ip_blacklist=hs.config.federation_ip_range_blacklist,
+ self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
)
self.clock = hs.get_clock()
@@ -266,27 +265,29 @@ class MatrixFederationHttpClient:
self._cooperator = Cooperator(scheduler=schedule)
async def _send_request_with_optional_trailing_slash(
- self, request, try_trailing_slash_on_400=False, **send_request_args
- ):
+ self,
+ request: MatrixFederationRequest,
+ try_trailing_slash_on_400: bool = False,
+ **send_request_args
+ ) -> IResponse:
"""Wrapper for _send_request which can optionally retry the request
upon receiving a combination of a 400 HTTP response code and a
'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3
due to #3622.
Args:
- request (MatrixFederationRequest): details of request to be sent
- try_trailing_slash_on_400 (bool): Whether on receiving a 400
+ request: details of request to be sent
+ try_trailing_slash_on_400: Whether on receiving a 400
'M_UNRECOGNIZED' from the server to retry the request with a
trailing slash appended to the request path.
- send_request_args (Dict): A dictionary of arguments to pass to
- `_send_request()`.
+ send_request_args: A dictionary of arguments to pass to `_send_request()`.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
(except 429).
Returns:
- Dict: Parsed JSON response body.
+ Parsed JSON response body.
"""
try:
response = await self._send_request(request, **send_request_args)
@@ -313,24 +314,26 @@ class MatrixFederationHttpClient:
async def _send_request(
self,
- request,
- retry_on_dns_fail=True,
- timeout=None,
- long_retries=False,
- ignore_backoff=False,
- backoff_on_404=False,
- ):
+ request: MatrixFederationRequest,
+ retry_on_dns_fail: bool = True,
+ timeout: Optional[int] = None,
+ long_retries: bool = False,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ ) -> IResponse:
"""
Sends a request to the given server.
Args:
- request (MatrixFederationRequest): details of request to be sent
+ request: details of request to be sent
+
+ retry_on_dns_fail: true if the request should be retied on DNS failures
- timeout (int|None): number of milliseconds to wait for the response headers
+ timeout: number of milliseconds to wait for the response headers
(including connecting to the server), *for each attempt*.
60s by default.
- long_retries (bool): whether to use the long retry algorithm.
+ long_retries: whether to use the long retry algorithm.
The regular retry algorithm makes 4 attempts, with intervals
[0.5s, 1s, 2s].
@@ -346,14 +349,13 @@ class MatrixFederationHttpClient:
NB: the long retry algorithm takes over 20 minutes to complete, with
a default timeout of 60s!
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
- backoff_on_404 (bool): Back off if we get a 404
+ backoff_on_404: Back off if we get a 404
Returns:
- twisted.web.client.Response: resolves with the HTTP
- response object on success.
+ Resolves with the HTTP response object on success.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
@@ -404,7 +406,7 @@ class MatrixFederationHttpClient:
)
# Inject the span into the headers
- headers_dict = {}
+ headers_dict = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers_dict, request.destination)
headers_dict[b"User-Agent"] = [self.version_string_bytes]
@@ -435,7 +437,7 @@ class MatrixFederationHttpClient:
data = encode_canonical_json(json)
producer = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator
- )
+ ) # type: Optional[IBodyProducer]
else:
producer = None
auth_headers = self.build_auth_headers(
@@ -524,14 +526,16 @@ class MatrixFederationHttpClient:
)
body = None
- e = HttpResponseException(response.code, response_phrase, body)
+ exc = HttpResponseException(
+ response.code, response_phrase, body
+ )
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
if response.code == 429:
- raise RequestSendFailed(e, can_retry=True) from e
+ raise RequestSendFailed(exc, can_retry=True) from exc
else:
- raise e
+ raise exc
break
except RequestSendFailed as e:
@@ -582,22 +586,27 @@ class MatrixFederationHttpClient:
return response
def build_auth_headers(
- self, destination, method, url_bytes, content=None, destination_is=None
- ):
+ self,
+ destination: Optional[bytes],
+ method: bytes,
+ url_bytes: bytes,
+ content: Optional[JsonDict] = None,
+ destination_is: Optional[bytes] = None,
+ ) -> List[bytes]:
"""
Builds the Authorization headers for a federation request
Args:
- destination (bytes|None): The destination homeserver of the request.
+ destination: The destination homeserver of the request.
May be None if the destination is an identity server, in which case
destination_is must be non-None.
- method (bytes): The HTTP method of the request
- url_bytes (bytes): The URI path of the request
- content (object): The body of the request
- destination_is (bytes): As 'destination', but if the destination is an
+ method: The HTTP method of the request
+ url_bytes: The URI path of the request
+ content: The body of the request
+ destination_is: As 'destination', but if the destination is an
identity server
Returns:
- list[bytes]: a list of headers to be added as "Authorization:" headers
+ A list of headers to be added as "Authorization:" headers
"""
request = {
"method": method.decode("ascii"),
@@ -629,33 +638,32 @@ class MatrixFederationHttpClient:
async def put_json(
self,
- destination,
- path,
- args={},
- data={},
- json_data_callback=None,
- long_retries=False,
- timeout=None,
- ignore_backoff=False,
- backoff_on_404=False,
- try_trailing_slash_on_400=False,
- ):
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ data: Optional[JsonDict] = None,
+ json_data_callback: Optional[Callable[[], JsonDict]] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ ) -> Union[JsonDict, list]:
""" Sends the specified json data using PUT
Args:
- destination (str): The remote server to send the HTTP request
- to.
- path (str): The HTTP path.
- args (dict): query params
- data (dict): A dict containing the data that will be used as
+ destination: The remote server to send the HTTP request to.
+ path: The HTTP path.
+ args: query params
+ data: A dict containing the data that will be used as
the request body. This will be encoded as JSON.
- json_data_callback (callable): A callable returning the dict to
+ json_data_callback: A callable returning the dict to
use as the request body.
- long_retries (bool): whether to use the long retry algorithm. See
+ long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -663,19 +671,19 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
- backoff_on_404 (bool): True if we should count a 404 response as
+ backoff_on_404: True if we should count a 404 response as
a failure of the server (and should therefore back off future
requests).
- try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
+ try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end
of the request. Workaround for #3622 in Synapse <= v0.99.3. This
will be attempted before backing off if backing off has been
enabled.
Returns:
- dict|list: Succeeds when we get a 2xx HTTP response. The
+ Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@@ -721,29 +729,28 @@ class MatrixFederationHttpClient:
async def post_json(
self,
- destination,
- path,
- data={},
- long_retries=False,
- timeout=None,
- ignore_backoff=False,
- args={},
- ):
+ destination: str,
+ path: str,
+ data: Optional[JsonDict] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ args: Optional[QueryArgs] = None,
+ ) -> Union[JsonDict, list]:
""" Sends the specified json data using POST
Args:
- destination (str): The remote server to send the HTTP request
- to.
+ destination: The remote server to send the HTTP request to.
- path (str): The HTTP path.
+ path: The HTTP path.
- data (dict): A dict containing the data that will be used as
+ data: A dict containing the data that will be used as
the request body. This will be encoded as JSON.
- long_retries (bool): whether to use the long retry algorithm. See
+ long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -751,10 +758,10 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data and
+ ignore_backoff: true to ignore the historical backoff data and
try the request anyway.
- args (dict): query params
+ args: query params
Returns:
dict|list: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
@@ -795,26 +802,25 @@ class MatrixFederationHttpClient:
async def get_json(
self,
- destination,
- path,
- args=None,
- retry_on_dns_fail=True,
- timeout=None,
- ignore_backoff=False,
- try_trailing_slash_on_400=False,
- ):
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ retry_on_dns_fail: bool = True,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ ) -> Union[JsonDict, list]:
""" GETs some json from the given host homeserver and path
Args:
- destination (str): The remote server to send the HTTP request
- to.
+ destination: The remote server to send the HTTP request to.
- path (str): The HTTP path.
+ path: The HTTP path.
- args (dict|None): A dictionary used to create query strings, defaults to
+ args: A dictionary used to create query strings, defaults to
None.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -822,14 +828,14 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
- try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
+ try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
Returns:
- dict|list: Succeeds when we get a 2xx HTTP response. The
+ Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@@ -870,24 +876,23 @@ class MatrixFederationHttpClient:
async def delete_json(
self,
- destination,
- path,
- long_retries=False,
- timeout=None,
- ignore_backoff=False,
- args={},
- ):
+ destination: str,
+ path: str,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ args: Optional[QueryArgs] = None,
+ ) -> Union[JsonDict, list]:
"""Send a DELETE request to the remote expecting some json response
Args:
- destination (str): The remote server to send the HTTP request
- to.
- path (str): The HTTP path.
+ destination: The remote server to send the HTTP request to.
+ path: The HTTP path.
- long_retries (bool): whether to use the long retry algorithm. See
+ long_retries: whether to use the long retry algorithm. See
docs on _send_request for details.
- timeout (int|None): number of milliseconds to wait for the response.
+ timeout: number of milliseconds to wait for the response.
self._default_timeout (60s) by default.
Note that we may make several attempts to send the request; this
@@ -895,12 +900,12 @@ class MatrixFederationHttpClient:
*each* attempt (including connection time) as well as the time spent
reading the response body after a 200 response.
- ignore_backoff (bool): true to ignore the historical backoff data and
+ ignore_backoff: true to ignore the historical backoff data and
try the request anyway.
- args (dict): query params
+ args: query params
Returns:
- dict|list: Succeeds when we get a 2xx HTTP response. The
+ Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Raises:
@@ -938,25 +943,25 @@ class MatrixFederationHttpClient:
async def get_file(
self,
- destination,
- path,
+ destination: str,
+ path: str,
output_stream,
- args={},
- retry_on_dns_fail=True,
- max_size=None,
- ignore_backoff=False,
- ):
+ args: Optional[QueryArgs] = None,
+ retry_on_dns_fail: bool = True,
+ max_size: Optional[int] = None,
+ ignore_backoff: bool = False,
+ ) -> Tuple[int, Dict[bytes, List[bytes]]]:
"""GETs a file from a given homeserver
Args:
- destination (str): The remote server to send the HTTP request to.
- path (str): The HTTP path to GET.
- output_stream (file): File to write the response body to.
- args (dict): Optional dictionary used to create the query string.
- ignore_backoff (bool): true to ignore the historical backoff data
+ destination: The remote server to send the HTTP request to.
+ path: The HTTP path to GET.
+ output_stream: File to write the response body to.
+ args: Optional dictionary used to create the query string.
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
Returns:
- tuple[int, dict]: Resolves with an (int,dict) tuple of
+ Resolves with an (int,dict) tuple of
the file length and a dict of the response headers.
Raises:
@@ -980,7 +985,7 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
try:
- d = _readBodyToFile(response, output_stream, max_size)
+ d = readBodyToFile(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor)
length = await make_deferred_yieldable(d)
except Exception as e:
@@ -1004,40 +1009,6 @@ class MatrixFederationHttpClient:
return (length, headers)
-class _ReadBodyToFileProtocol(protocol.Protocol):
- def __init__(self, stream, deferred, max_size):
- self.stream = stream
- self.deferred = deferred
- self.length = 0
- self.max_size = max_size
-
- def dataReceived(self, data):
- self.stream.write(data)
- self.length += len(data)
- if self.max_size is not None and self.length >= self.max_size:
- self.deferred.errback(
- SynapseError(
- 502,
- "Requested file is too large > %r bytes" % (self.max_size,),
- Codes.TOO_LARGE,
- )
- )
- self.deferred = defer.Deferred()
- self.transport.loseConnection()
-
- def connectionLost(self, reason):
- if reason.check(ResponseDone):
- self.deferred.callback(self.length)
- else:
- self.deferred.errback(reason)
-
-
-def _readBodyToFile(response, stream, max_size):
- d = defer.Deferred()
- response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
- return d
-
-
def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
@@ -1049,13 +1020,13 @@ def _flatten_response_never_received(e):
return repr(e)
-def check_content_type_is_json(headers):
+def check_content_type_is_json(headers: Headers) -> None:
"""
Check that a set of HTTP headers have a Content-Type header, and that it
is application/json.
Args:
- headers (twisted.web.http_headers.Headers): headers to check
+ headers: headers to check
Raises:
RequestSendFailed: if the Content-Type header is missing or isn't JSON
@@ -1078,18 +1049,3 @@ def check_content_type_is_json(headers):
),
can_retry=False,
)
-
-
-def encode_query_args(args):
- if args is None:
- return b""
-
- encoded_args = {}
- for k, vs in args.items():
- if isinstance(vs, str):
- vs = [vs]
- encoded_args[k] = [v.encode("UTF-8") for v in vs]
-
- query_bytes = urllib.parse.urlencode(encoded_args, True)
-
- return query_bytes.encode("utf8")
diff --git a/synapse/http/server.py b/synapse/http/server.py
index c0919f8cb7..6a4e429a6c 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -25,7 +25,7 @@ from io import BytesIO
from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
import jinja2
-from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
+from canonicaljson import iterencode_canonical_json
from zope.interface import implementer
from twisted.internet import defer, interfaces
@@ -94,11 +94,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
pass
else:
respond_with_json(
- request,
- error_code,
- error_dict,
- send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
+ request, error_code, error_dict, send_cors=True,
)
@@ -290,7 +286,6 @@ class DirectServeJsonResource(_AsyncResource):
code,
response_object,
send_cors=True,
- pretty_print=_request_user_agent_is_curl(request),
canonical_json=self.canonical_json,
)
@@ -587,7 +582,6 @@ def respond_with_json(
code: int,
json_object: Any,
send_cors: bool = False,
- pretty_print: bool = False,
canonical_json: bool = True,
):
"""Sends encoded JSON in response to the given request.
@@ -598,8 +592,6 @@ def respond_with_json(
json_object: The object to serialize to JSON.
send_cors: Whether to send Cross-Origin Resource Sharing headers
https://fetch.spec.whatwg.org/#http-cors-protocol
- pretty_print: Whether to include indentation and line-breaks in the
- resulting JSON bytes.
canonical_json: Whether to use the canonicaljson algorithm when encoding
the JSON bytes.
@@ -615,13 +607,10 @@ def respond_with_json(
)
return None
- if pretty_print:
- encoder = iterencode_pretty_printed_json
+ if canonical_json:
+ encoder = iterencode_canonical_json
else:
- if canonical_json:
- encoder = iterencode_canonical_json
- else:
- encoder = _encode_json_bytes
+ encoder = _encode_json_bytes
request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json")
@@ -685,7 +674,7 @@ def set_cors_headers(request: Request):
)
request.setHeader(
b"Access-Control-Allow-Headers",
- b"Origin, X-Requested-With, Content-Type, Accept, Authorization",
+ b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date",
)
@@ -759,11 +748,3 @@ def finish_request(request: Request):
request.finish()
except RuntimeError as e:
logger.info("Connection disconnected before response was written: %r", e)
-
-
-def _request_user_agent_is_curl(request: Request) -> bool:
- user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[])
- for user_agent in user_agents:
- if b"curl" in user_agent:
- return True
- return False
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 0142542852..72ab5750cc 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -49,6 +49,7 @@ class ModuleApi:
self._store = hs.get_datastore()
self._auth = hs.get_auth()
self._auth_handler = auth_handler
+ self._server_name = hs.hostname
# We expose these as properties below in order to attach a helpful docstring.
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
@@ -336,7 +337,9 @@ class ModuleApi:
SynapseError if the event was not allowed.
"""
# Create a requester object
- requester = create_requester(event_dict["sender"])
+ requester = create_requester(
+ event_dict["sender"], authenticated_entity=self._server_name
+ )
# Create and send the event
(
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 793d0db2d9..eff0975b6a 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -75,6 +75,7 @@ class HttpPusher:
self.failing_since = pusherdict["failing_since"]
self.timed_call = None
self._is_processing = False
+ self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
@@ -136,7 +137,11 @@ class HttpPusher:
async def _update_badge(self):
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it.
- badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
+ badge = await push_tools.get_badge_count(
+ self.hs.get_datastore(),
+ self.user_id,
+ group_by_room=self._group_unread_count_by_room,
+ )
await self._send_badge(badge)
def on_timer(self):
@@ -283,7 +288,11 @@ class HttpPusher:
return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
- badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
+ badge = await push_tools.get_badge_count(
+ self.hs.get_datastore(),
+ self.user_id,
+ group_by_room=self._group_unread_count_by_room,
+ )
event = await self.store.get_event(push_action["event_id"], allow_none=True)
if event is None:
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index d0145666bf..6e7c880dc0 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -12,12 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
+from synapse.storage.databases.main import DataStore
-async def get_badge_count(store, user_id):
+async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)
@@ -34,9 +34,15 @@ async def get_badge_count(store, user_id):
room_id, user_id, last_unread_event_id
)
)
- # return one badge count per conversation, as count per
- # message is so noisy as to be almost useless
- badge += 1 if notifs["notify_count"] else 0
+ if notifs["notify_count"] == 0:
+ continue
+
+ if group_by_room:
+ # return one badge count per conversation
+ badge += 1
+ else:
+ # increment the badge count by the number of unread messages in the room
+ badge += notifs["notify_count"]
return badge
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 040f76bb53..c97e0df1f5 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -40,6 +40,10 @@ logger = logging.getLogger(__name__)
# Note that these both represent runtime dependencies (and the versions
# installed are checked at runtime).
#
+# Also note that we replicate these constraints in the Synapse Dockerfile while
+# pre-installing dependencies. If these constraints are updated here, the same
+# change should be made in the Dockerfile.
+#
# [1] https://pip.pypa.io/en/stable/reference/pip_install/#requirement-specifiers.
REQUIREMENTS = [
@@ -69,14 +73,7 @@ REQUIREMENTS = [
"msgpack>=0.5.2",
"phonenumbers>=8.2.0",
# we use GaugeHistogramMetric, which was added in prom-client 0.4.0.
- # prom-client has a history of breaking backwards compatibility between
- # minor versions (https://github.com/prometheus/client_python/issues/317),
- # so we also pin the minor version.
- #
- # Note that we replicate these constraints in the Synapse Dockerfile while
- # pre-installing dependencies. If these constraints are updated here, the
- # same change should be made in the Dockerfile.
- "prometheus_client>=0.4.0,<0.9.0",
+ "prometheus_client>=0.4.0",
# we use attr.validators.deep_iterable, which arrived in 19.1.0 (Note:
# Fedora 31 only has 19.1, so if we want to upgrade we should wait until 33
# is out in November.)
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index b4f4a68b5c..7a0dbb5b1a 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -254,20 +254,20 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
return 200, {}
-class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
+class ReplicationStoreRoomOnOutlierMembershipRestServlet(ReplicationEndpoint):
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
Request format:
- POST /_synapse/replication/store_room_on_invite/:room_id/:txn_id
+ POST /_synapse/replication/store_room_on_outlier_membership/:room_id/:txn_id
{
"room_version": "1",
}
"""
- NAME = "store_room_on_invite"
+ NAME = "store_room_on_outlier_membership"
PATH_ARGS = ("room_id",)
def __init__(self, hs):
@@ -282,7 +282,7 @@ class ReplicationStoreRoomOnInviteRestServlet(ReplicationEndpoint):
async def _handle_request(self, request, room_id):
content = parse_json_object_from_request(request)
room_version = KNOWN_ROOM_VERSIONS[content["room_version"]]
- await self.store.maybe_store_room_on_invite(room_id, room_version)
+ await self.store.maybe_store_room_on_outlier_membership(room_id, room_version)
return 200, {}
@@ -291,4 +291,4 @@ def register_servlets(hs, http_server):
ReplicationFederationSendEduRestServlet(hs).register(http_server)
ReplicationGetQueryRestServlet(hs).register(http_server)
ReplicationCleanRoomRestServlet(hs).register(http_server)
- ReplicationStoreRoomOnInviteRestServlet(hs).register(http_server)
+ ReplicationStoreRoomOnOutlierMembershipRestServlet(hs).register(http_server)
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index f0c37eaf5e..84e002f934 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
+
+from twisted.web.http import Request
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
@@ -52,16 +53,23 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
self.clock = hs.get_clock()
@staticmethod
- async def _serialize_payload(
- requester, room_id, user_id, remote_room_hosts, content
- ):
+ async def _serialize_payload( # type: ignore
+ requester: Requester,
+ room_id: str,
+ user_id: str,
+ remote_room_hosts: List[str],
+ content: JsonDict,
+ ) -> JsonDict:
"""
Args:
- requester(Requester)
- room_id (str)
- user_id (str)
- remote_room_hosts (list[str]): Servers to try and join via
- content(dict): The event content to use for the join event
+ requester: The user making the request according to the access token
+ room_id: The ID of the room.
+ user_id: The ID of the user.
+ remote_room_hosts: Servers to try and join via
+ content: The event content to use for the join event
+
+ Returns:
+ A dict representing the payload of the request.
"""
return {
"requester": requester.serialize(),
@@ -69,7 +77,9 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
"content": content,
}
- async def _handle_request(self, request, room_id, user_id):
+ async def _handle_request( # type: ignore
+ self, request: Request, room_id: str, user_id: str
+ ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
remote_room_hosts = content["remote_room_hosts"]
@@ -118,14 +128,17 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
- ):
+ ) -> JsonDict:
"""
Args:
- invite_event_id: ID of the invite to be rejected
- txn_id: optional transaction ID supplied by the client
- requester: user making the rejection request, according to the access token
- content: additional content to include in the rejection event.
+ invite_event_id: The ID of the invite to be rejected.
+ txn_id: Optional transaction ID supplied by the client
+ requester: User making the rejection request, according to the access token
+ content: Additional content to include in the rejection event.
Normally an empty dict.
+
+ Returns:
+ A dict representing the payload of the request.
"""
return {
"txn_id": txn_id,
@@ -133,7 +146,9 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
"content": content,
}
- async def _handle_request(self, request, invite_event_id):
+ async def _handle_request( # type: ignore
+ self, request: Request, invite_event_id: str
+ ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
txn_id = content["txn_id"]
@@ -174,18 +189,25 @@ class ReplicationUserJoinedLeftRoomRestServlet(ReplicationEndpoint):
self.distributor = hs.get_distributor()
@staticmethod
- async def _serialize_payload(room_id, user_id, change):
+ async def _serialize_payload( # type: ignore
+ room_id: str, user_id: str, change: str
+ ) -> JsonDict:
"""
Args:
- room_id (str)
- user_id (str)
- change (str): "left"
+ room_id: The ID of the room.
+ user_id: The ID of the user.
+ change: "left"
+
+ Returns:
+ A dict representing the payload of the request.
"""
assert change == "left"
return {}
- def _handle_request(self, request, room_id, user_id, change):
+ def _handle_request( # type: ignore
+ self, request: Request, room_id: str, user_id: str, change: str
+ ) -> Tuple[int, JsonDict]:
logger.info("user membership change: %s in %s", user_id, room_id)
user = UserID.from_string(user_id)
diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py
index 2a4f7a1740..55ddebb4fe 100644
--- a/synapse/rest/admin/__init__.py
+++ b/synapse/rest/admin/__init__.py
@@ -21,11 +21,7 @@ import synapse
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.rest.admin._base import (
- admin_patterns,
- assert_requester_is_admin,
- historical_admin_path_patterns,
-)
+from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.rest.admin.devices import (
DeleteDevicesRestServlet,
DeviceRestServlet,
@@ -61,6 +57,7 @@ from synapse.rest.admin.users import (
UserRestServletV2,
UsersRestServlet,
UsersRestServletV2,
+ UserTokenRestServlet,
WhoisRestServlet,
)
from synapse.types import RoomStreamToken
@@ -83,7 +80,7 @@ class VersionServlet(RestServlet):
class PurgeHistoryRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns(
+ PATTERNS = admin_patterns(
"/purge_history/(?P<room_id>[^/]*)(/(?P<event_id>[^/]+))?"
)
@@ -168,9 +165,7 @@ class PurgeHistoryRestServlet(RestServlet):
class PurgeHistoryStatusRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns(
- "/purge_history_status/(?P<purge_id>[^/]+)"
- )
+ PATTERNS = admin_patterns("/purge_history_status/(?P<purge_id>[^/]+)")
def __init__(self, hs):
"""
@@ -223,6 +218,7 @@ def register_servlets(hs, http_server):
UserAdminServlet(hs).register(http_server)
UserMediaRestServlet(hs).register(http_server)
UserMembershipRestServlet(hs).register(http_server)
+ UserTokenRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server)
DeviceRestServlet(hs).register(http_server)
diff --git a/synapse/rest/admin/_base.py b/synapse/rest/admin/_base.py
index db9fea263a..e09234c644 100644
--- a/synapse/rest/admin/_base.py
+++ b/synapse/rest/admin/_base.py
@@ -22,28 +22,6 @@ from synapse.api.errors import AuthError
from synapse.types import UserID
-def historical_admin_path_patterns(path_regex):
- """Returns the list of patterns for an admin endpoint, including historical ones
-
- This is a backwards-compatibility hack. Previously, the Admin API was exposed at
- various paths under /_matrix/client. This function returns a list of patterns
- matching those paths (as well as the new one), so that existing scripts which rely
- on the endpoints being available there are not broken.
-
- Note that this should only be used for existing endpoints: new ones should just
- register for the /_synapse/admin path.
- """
- return [
- re.compile(prefix + path_regex)
- for prefix in (
- "^/_synapse/admin/v1",
- "^/_matrix/client/api/v1/admin",
- "^/_matrix/client/unstable/admin",
- "^/_matrix/client/r0/admin",
- )
- ]
-
-
def admin_patterns(path_regex: str, version: str = "v1"):
"""Returns the list of patterns for an admin endpoint
diff --git a/synapse/rest/admin/groups.py b/synapse/rest/admin/groups.py
index 0b54ca09f4..d0c86b204a 100644
--- a/synapse/rest/admin/groups.py
+++ b/synapse/rest/admin/groups.py
@@ -16,10 +16,7 @@ import logging
from synapse.api.errors import SynapseError
from synapse.http.servlet import RestServlet
-from synapse.rest.admin._base import (
- assert_user_is_admin,
- historical_admin_path_patterns,
-)
+from synapse.rest.admin._base import admin_patterns, assert_user_is_admin
logger = logging.getLogger(__name__)
@@ -28,7 +25,7 @@ class DeleteGroupAdminRestServlet(RestServlet):
"""Allows deleting of local groups
"""
- PATTERNS = historical_admin_path_patterns("/delete_group/(?P<group_id>[^/]*)")
+ PATTERNS = admin_patterns("/delete_group/(?P<group_id>[^/]*)")
def __init__(self, hs):
self.group_server = hs.get_groups_server_handler()
diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py
index ba50cb876d..c82b4f87d6 100644
--- a/synapse/rest/admin/media.py
+++ b/synapse/rest/admin/media.py
@@ -22,7 +22,6 @@ from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
- historical_admin_path_patterns,
)
logger = logging.getLogger(__name__)
@@ -34,10 +33,10 @@ class QuarantineMediaInRoom(RestServlet):
"""
PATTERNS = (
- historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+ admin_patterns("/room/(?P<room_id>[^/]+)/media/quarantine")
+
# This path kept around for legacy reasons
- historical_admin_path_patterns("/quarantine_media/(?P<room_id>[^/]+)")
+ admin_patterns("/quarantine_media/(?P<room_id>[^/]+)")
)
def __init__(self, hs):
@@ -63,9 +62,7 @@ class QuarantineMediaByUser(RestServlet):
this server.
"""
- PATTERNS = historical_admin_path_patterns(
- "/user/(?P<user_id>[^/]+)/media/quarantine"
- )
+ PATTERNS = admin_patterns("/user/(?P<user_id>[^/]+)/media/quarantine")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -90,7 +87,7 @@ class QuarantineMediaByID(RestServlet):
it via this server.
"""
- PATTERNS = historical_admin_path_patterns(
+ PATTERNS = admin_patterns(
"/media/quarantine/(?P<server_name>[^/]+)/(?P<media_id>[^/]+)"
)
@@ -116,7 +113,7 @@ class ListMediaInRoom(RestServlet):
"""Lists all of the media in a given room.
"""
- PATTERNS = historical_admin_path_patterns("/room/(?P<room_id>[^/]+)/media")
+ PATTERNS = admin_patterns("/room/(?P<room_id>[^/]+)/media")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -134,7 +131,7 @@ class ListMediaInRoom(RestServlet):
class PurgeMediaCacheRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/purge_media_cache")
+ PATTERNS = admin_patterns("/purge_media_cache")
def __init__(self, hs):
self.media_repository = hs.get_media_repository()
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index f5304ff43d..25f89e4685 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -29,7 +29,6 @@ from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
- historical_admin_path_patterns,
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.types import RoomAlias, RoomID, UserID, create_requester
@@ -44,7 +43,7 @@ class ShutdownRoomRestServlet(RestServlet):
joined to the new room.
"""
- PATTERNS = historical_admin_path_patterns("/shutdown_room/(?P<room_id>[^/]+)")
+ PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
def __init__(self, hs):
self.hs = hs
@@ -71,14 +70,18 @@ class ShutdownRoomRestServlet(RestServlet):
class DeleteRoomRestServlet(RestServlet):
- """Delete a room from server. It is a combination and improvement of
- shut down and purge room.
+ """Delete a room from server.
+
+ It is a combination and improvement of shutdown and purge room.
+
Shuts down a room by removing all local users from the room.
Blocking all future invites and joins to the room is optional.
+
If desired any local aliases will be repointed to a new room
- created by `new_room_user_id` and kicked users will be auto
+ created by `new_room_user_id` and kicked users will be auto-
joined to the new room.
- It will remove all trace of a room from the database.
+
+ If 'purge' is true, it will remove all traces of a room from the database.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
@@ -111,6 +114,14 @@ class DeleteRoomRestServlet(RestServlet):
Codes.BAD_JSON,
)
+ force_purge = content.get("force_purge", False)
+ if not isinstance(force_purge, bool):
+ raise SynapseError(
+ HTTPStatus.BAD_REQUEST,
+ "Param 'force_purge' must be a boolean, if given",
+ Codes.BAD_JSON,
+ )
+
ret = await self.room_shutdown_handler.shutdown_room(
room_id=room_id,
new_room_user_id=content.get("new_room_user_id"),
@@ -122,7 +133,7 @@ class DeleteRoomRestServlet(RestServlet):
# Purge room
if purge:
- await self.pagination_handler.purge_room(room_id)
+ await self.pagination_handler.purge_room(room_id, force=force_purge)
return (200, ret)
@@ -309,7 +320,9 @@ class JoinRoomAliasServlet(RestServlet):
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
- fake_requester = create_requester(target_user)
+ fake_requester = create_requester(
+ target_user, authenticated_entity=requester.authenticated_entity
+ )
# send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index 3638e219f2..b0ff5e1ead 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -16,7 +16,7 @@ import hashlib
import hmac
import logging
from http import HTTPStatus
-from typing import Tuple
+from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -33,10 +33,13 @@ from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
- historical_admin_path_patterns,
)
+from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.types import JsonDict, UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
_GET_PUSHERS_ALLOWED_KEYS = {
@@ -52,7 +55,7 @@ _GET_PUSHERS_ALLOWED_KEYS = {
class UsersRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/users/(?P<user_id>[^/]*)$")
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
def __init__(self, hs):
self.hs = hs
@@ -335,7 +338,7 @@ class UserRegisterServlet(RestServlet):
nonce to the time it was generated, in int seconds.
"""
- PATTERNS = historical_admin_path_patterns("/register")
+ PATTERNS = admin_patterns("/register")
NONCE_TIMEOUT = 60
def __init__(self, hs):
@@ -458,7 +461,14 @@ class UserRegisterServlet(RestServlet):
class WhoisRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/whois/(?P<user_id>[^/]*)")
+ path_regex = "/whois/(?P<user_id>[^/]*)$"
+ PATTERNS = (
+ admin_patterns(path_regex)
+ +
+ # URL for spec reason
+ # https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-admin-whois-userid
+ client_patterns("/admin" + path_regex, v1=True)
+ )
def __init__(self, hs):
self.hs = hs
@@ -482,7 +492,7 @@ class WhoisRestServlet(RestServlet):
class DeactivateAccountRestServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/deactivate/(?P<target_user_id>[^/]*)")
+ PATTERNS = admin_patterns("/deactivate/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self._deactivate_account_handler = hs.get_deactivate_account_handler()
@@ -513,7 +523,7 @@ class DeactivateAccountRestServlet(RestServlet):
class AccountValidityRenewServlet(RestServlet):
- PATTERNS = historical_admin_path_patterns("/account_validity/validity$")
+ PATTERNS = admin_patterns("/account_validity/validity$")
def __init__(self, hs):
"""
@@ -556,9 +566,7 @@ class ResetPasswordRestServlet(RestServlet):
200 OK with empty object if success otherwise an error.
"""
- PATTERNS = historical_admin_path_patterns(
- "/reset_password/(?P<target_user_id>[^/]*)"
- )
+ PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -600,7 +608,7 @@ class SearchUsersRestServlet(RestServlet):
200 OK with json object {list[dict[str, Any]], count} or empty object.
"""
- PATTERNS = historical_admin_path_patterns("/search_users/(?P<target_user_id>[^/]*)")
+ PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
self.hs = hs
@@ -828,3 +836,52 @@ class UserMediaRestServlet(RestServlet):
ret["next_token"] = start + len(media)
return 200, ret
+
+
+class UserTokenRestServlet(RestServlet):
+ """An admin API for logging in as a user.
+
+ Example:
+
+ POST /_synapse/admin/v1/users/@test:example.com/login
+ {}
+
+ 200 OK
+ {
+ "access_token": "<some_token>"
+ }
+ """
+
+ PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
+
+ def __init__(self, hs: "HomeServer"):
+ self.hs = hs
+ self.store = hs.get_datastore()
+ self.auth = hs.get_auth()
+ self.auth_handler = hs.get_auth_handler()
+
+ async def on_POST(self, request, user_id):
+ requester = await self.auth.get_user_by_req(request)
+ await assert_user_is_admin(self.auth, requester.user)
+ auth_user = requester.user
+
+ if not self.hs.is_mine_id(user_id):
+ raise SynapseError(400, "Only local users can be logged in as")
+
+ body = parse_json_object_from_request(request, allow_empty_body=True)
+
+ valid_until_ms = body.get("valid_until_ms")
+ if valid_until_ms and not isinstance(valid_until_ms, int):
+ raise SynapseError(400, "'valid_until_ms' parameter must be an int")
+
+ if auth_user.to_string() == user_id:
+ raise SynapseError(400, "Cannot use admin API to login as self")
+
+ token = await self.auth_handler.get_access_token_for_user_id(
+ user_id=auth_user.to_string(),
+ device_id=None,
+ valid_until_ms=valid_until_ms,
+ puppets_user_id=user_id,
+ )
+
+ return 200, {"access_token": token}
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 94452fcbf5..d7ae148214 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,10 +19,6 @@ from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
-from synapse.handlers.auth import (
- convert_client_dict_legacy_fields_to_identifier,
- login_id_phone_to_thirdparty,
-)
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@@ -33,7 +29,6 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
-from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
@@ -78,11 +73,6 @@ class LoginRestServlet(RestServlet):
rate_hz=self.hs.config.rc_login_account.per_second,
burst_count=self.hs.config.rc_login_account.burst_count,
)
- self._failed_attempts_ratelimiter = Ratelimiter(
- clock=hs.get_clock(),
- rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
- burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
- )
def on_GET(self, request: SynapseRequest):
flows = []
@@ -140,27 +130,31 @@ class LoginRestServlet(RestServlet):
result["well_known"] = well_known_data
return 200, result
- def _get_qualified_user_id(self, identifier):
- if identifier["type"] != "m.id.user":
- raise SynapseError(400, "Unknown login identifier type")
- if "user" not in identifier:
- raise SynapseError(400, "User identifier is missing 'user' key")
-
- if identifier["user"].startswith("@"):
- return identifier["user"]
- else:
- return UserID(identifier["user"], self.hs.hostname).to_string()
-
async def _do_appservice_login(
self, login_submission: JsonDict, appservice: ApplicationService
):
- logger.info(
- "Got appservice login request with identifier: %r",
- login_submission.get("identifier"),
- )
+ identifier = login_submission.get("identifier")
+ logger.info("Got appservice login request with identifier: %r", identifier)
+
+ if not isinstance(identifier, dict):
+ raise SynapseError(
+ 400, "Invalid identifier in login submission", Codes.INVALID_PARAM
+ )
+
+ # this login flow only supports identifiers of type "m.id.user".
+ if identifier.get("type") != "m.id.user":
+ raise SynapseError(
+ 400, "Unknown login identifier type", Codes.INVALID_PARAM
+ )
- identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
- qualified_user_id = self._get_qualified_user_id(identifier)
+ user = identifier.get("user")
+ if not isinstance(user, str):
+ raise SynapseError(400, "Invalid user in identifier", Codes.INVALID_PARAM)
+
+ if user.startswith("@"):
+ qualified_user_id = user
+ else:
+ qualified_user_id = UserID(user, self.hs.hostname).to_string()
if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
@@ -186,91 +180,9 @@ class LoginRestServlet(RestServlet):
login_submission.get("address"),
login_submission.get("user"),
)
- identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
-
- # convert phone type identifiers to generic threepids
- if identifier["type"] == "m.id.phone":
- identifier = login_id_phone_to_thirdparty(identifier)
-
- # convert threepid identifiers to user IDs
- if identifier["type"] == "m.id.thirdparty":
- address = identifier.get("address")
- medium = identifier.get("medium")
-
- if medium is None or address is None:
- raise SynapseError(400, "Invalid thirdparty identifier")
-
- # For emails, canonicalise the address.
- # We store all email addresses canonicalised in the DB.
- # (See add_threepid in synapse/handlers/auth.py)
- if medium == "email":
- try:
- address = canonicalise_email(address)
- except ValueError as e:
- raise SynapseError(400, str(e))
-
- # We also apply account rate limiting using the 3PID as a key, as
- # otherwise using 3PID bypasses the ratelimiting based on user ID.
- self._failed_attempts_ratelimiter.ratelimit((medium, address), update=False)
-
- # Check for login providers that support 3pid login types
- (
- canonical_user_id,
- callback_3pid,
- ) = await self.auth_handler.check_password_provider_3pid(
- medium, address, login_submission["password"]
- )
- if canonical_user_id:
- # Authentication through password provider and 3pid succeeded
-
- result = await self._complete_login(
- canonical_user_id, login_submission, callback_3pid
- )
- return result
-
- # No password providers were able to handle this 3pid
- # Check local store
- user_id = await self.hs.get_datastore().get_user_id_by_threepid(
- medium, address
- )
- if not user_id:
- logger.warning(
- "unknown 3pid identifier medium %s, address %r", medium, address
- )
- # We mark that we've failed to log in here, as
- # `check_password_provider_3pid` might have returned `None` due
- # to an incorrect password, rather than the account not
- # existing.
- #
- # If it returned None but the 3PID was bound then we won't hit
- # this code path, which is fine as then the per-user ratelimit
- # will kick in below.
- self._failed_attempts_ratelimiter.can_do_action((medium, address))
- raise LoginError(403, "", errcode=Codes.FORBIDDEN)
-
- identifier = {"type": "m.id.user", "user": user_id}
-
- # by this point, the identifier should be an m.id.user: if it's anything
- # else, we haven't understood it.
- qualified_user_id = self._get_qualified_user_id(identifier)
-
- # Check if we've hit the failed ratelimit (but don't update it)
- self._failed_attempts_ratelimiter.ratelimit(
- qualified_user_id.lower(), update=False
+ canonical_user_id, callback = await self.auth_handler.validate_login(
+ login_submission, ratelimit=True
)
-
- try:
- canonical_user_id, callback = await self.auth_handler.validate_login(
- identifier["user"], login_submission
- )
- except LoginError:
- # The user has failed to log in, so we need to update the rate
- # limiter. Using `can_do_action` avoids us raising a ratelimit
- # exception and masking the LoginError. The actual ratelimiting
- # should have happened above.
- self._failed_attempts_ratelimiter.can_do_action(qualified_user_id.lower())
- raise
-
result = await self._complete_login(
canonical_user_id, login_submission, callback
)
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 25d3cc6148..93c06afe27 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -18,7 +18,7 @@
import logging
import re
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional
from urllib import parse as urlparse
from synapse.api.constants import EventTypes, Membership
@@ -48,8 +48,7 @@ from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID,
from synapse.util import json_decoder
from synapse.util.stringutils import random_string
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
import synapse.server
logger = logging.getLogger(__name__)
diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py
index a54e1011f7..eebee44a44 100644
--- a/synapse/rest/client/v2_alpha/account.py
+++ b/synapse/rest/client/v2_alpha/account.py
@@ -115,7 +115,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND)
@@ -387,7 +387,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -466,7 +466,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index ea68114026..a89ae6ddf9 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -135,7 +135,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE)
@@ -214,7 +214,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
# comments for request_token_inhibit_3pid_errors.
# Also wait for some random amount of time between 100ms and 1s to make it
# look like we did something.
- await self.hs.clock.sleep(random.randint(1, 10) / 10)
+ await self.hs.get_clock().sleep(random.randint(1, 10) / 10)
return 200, {"sid": random_string(16)}
raise SynapseError(
diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py
index 2b84eb89c0..8e52e4cca4 100644
--- a/synapse/rest/client/v2_alpha/sync.py
+++ b/synapse/rest/client/v2_alpha/sync.py
@@ -171,6 +171,7 @@ class SyncRestServlet(RestServlet):
)
with context:
sync_result = await self.sync_handler.wait_for_sync_for_user(
+ requester,
sync_config,
since_token=since_token,
timeout=timeout,
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index c16280f668..d8e8e48c1c 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -66,7 +66,7 @@ class LocalKey(Resource):
def __init__(self, hs):
self.config = hs.config
- self.clock = hs.clock
+ self.clock = hs.get_clock()
self.update_response_body(self.clock.time_msec())
Resource.__init__(self)
diff --git a/synapse/server.py b/synapse/server.py
index 21a232bbd9..b017e3489f 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -27,7 +27,8 @@ import logging
import os
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, TypeVar, cast
-import twisted
+import twisted.internet.base
+import twisted.internet.tcp
from twisted.mail.smtp import sendmail
from twisted.web.iweb import IPolicyForHTTPS
@@ -89,6 +90,7 @@ from synapse.handlers.room_member import RoomMemberMasterHandler
from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.search import SearchHandler
from synapse.handlers.set_password import SetPasswordHandler
+from synapse.handlers.sso import SsoHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
@@ -145,7 +147,8 @@ def cache_in_self(builder: T) -> T:
"@cache_in_self can only be used on functions starting with `get_`"
)
- depname = builder.__name__[len("get_") :]
+ # get_attr -> _attr
+ depname = builder.__name__[len("get") :]
building = [False]
@@ -233,15 +236,6 @@ class HomeServer(metaclass=abc.ABCMeta):
self._instance_id = random_string(5)
self._instance_name = config.worker_name or "master"
- self.clock = Clock(reactor)
- self.distributor = Distributor()
-
- self.registration_ratelimiter = Ratelimiter(
- clock=self.clock,
- rate_hz=config.rc_registration.per_second,
- burst_count=config.rc_registration.burst_count,
- )
-
self.version_string = version_string
self.datastores = None # type: Optional[Databases]
@@ -299,8 +293,9 @@ class HomeServer(metaclass=abc.ABCMeta):
def is_mine_id(self, string: str) -> bool:
return string.split(":", 1)[1] == self.hostname
+ @cache_in_self
def get_clock(self) -> Clock:
- return self.clock
+ return Clock(self._reactor)
def get_datastore(self) -> DataStore:
if not self.datastores:
@@ -317,11 +312,17 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_config(self) -> HomeServerConfig:
return self.config
+ @cache_in_self
def get_distributor(self) -> Distributor:
- return self.distributor
+ return Distributor()
+ @cache_in_self
def get_registration_ratelimiter(self) -> Ratelimiter:
- return self.registration_ratelimiter
+ return Ratelimiter(
+ clock=self.get_clock(),
+ rate_hz=self.config.rc_registration.per_second,
+ burst_count=self.config.rc_registration.burst_count,
+ )
@cache_in_self
def get_federation_client(self) -> FederationClient:
@@ -391,6 +392,10 @@ class HomeServer(metaclass=abc.ABCMeta):
return FollowerTypingHandler(self)
@cache_in_self
+ def get_sso_handler(self) -> SsoHandler:
+ return SsoHandler(self)
+
+ @cache_in_self
def get_sync_handler(self) -> SyncHandler:
return SyncHandler(self)
@@ -681,7 +686,7 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_federation_ratelimiter(self) -> FederationRateLimiter:
- return FederationRateLimiter(self.clock, config=self.config.rc_federation)
+ return FederationRateLimiter(self.get_clock(), config=self.config.rc_federation)
@cache_in_self
def get_module_api(self) -> ModuleApi:
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py
index d464c75c03..100dbd5e2c 100644
--- a/synapse/server_notices/server_notices_manager.py
+++ b/synapse/server_notices/server_notices_manager.py
@@ -39,6 +39,7 @@ class ServerNoticesManager:
self._room_member_handler = hs.get_room_member_handler()
self._event_creation_handler = hs.get_event_creation_handler()
self._is_mine_id = hs.is_mine_id
+ self._server_name = hs.hostname
self._notifier = hs.get_notifier()
self.server_notices_mxid = self._config.server_notices_mxid
@@ -72,7 +73,9 @@ class ServerNoticesManager:
await self.maybe_invite_user_to_room(user_id, room_id)
system_mxid = self._config.server_notices_mxid
- requester = create_requester(system_mxid)
+ requester = create_requester(
+ system_mxid, authenticated_entity=self._server_name
+ )
logger.info("Sending server notice to %s", user_id)
@@ -145,7 +148,9 @@ class ServerNoticesManager:
"avatar_url": self._config.server_notices_mxid_avatar_url,
}
- requester = create_requester(self.server_notices_mxid)
+ requester = create_requester(
+ self.server_notices_mxid, authenticated_entity=self._server_name
+ )
info, _ = await self._room_creation_handler.create_room(
requester,
config={
@@ -174,7 +179,9 @@ class ServerNoticesManager:
user_id: The ID of the user to invite.
room_id: The ID of the room to invite the user to.
"""
- requester = create_requester(self.server_notices_mxid)
+ requester = create_requester(
+ self.server_notices_mxid, authenticated_entity=self._server_name
+ )
# Check whether the user has already joined or been invited to this room. If
# that's the case, there is no need to re-invite them.
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index ecfc6717b3..5d668aadb2 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -314,6 +314,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
for table in (
"event_auth",
"event_edges",
+ "event_json",
"event_push_actions_staging",
"event_reference_hashes",
"event_relations",
@@ -340,7 +341,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"destination_rooms",
"event_backward_extremities",
"event_forward_extremities",
- "event_json",
"event_push_actions",
"event_search",
"events",
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index ca7917c989..1e7949a323 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -278,7 +278,8 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
async def get_linearized_receipts_for_all_rooms(
self, to_key: int, from_key: Optional[int] = None
) -> Dict[str, JsonDict]:
- """Get receipts for all rooms between two stream_ids.
+ """Get receipts for all rooms between two stream_ids, up
+ to a limit of the latest 100 read receipts.
Args:
to_key: Max stream id to fetch receipts upto.
@@ -294,12 +295,16 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ?
+ ORDER BY stream_id DESC
+ LIMIT 100
"""
txn.execute(sql, [from_key, to_key])
else:
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id <= ?
+ ORDER BY stream_id DESC
+ LIMIT 100
"""
txn.execute(sql, [to_key])
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index e5d07ce72a..fedb8a6c26 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1110,6 +1110,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
token: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
+ puppets_user_id: Optional[str] = None,
) -> int:
"""Adds an access token for the given user.
@@ -1133,6 +1134,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
"token": token,
"device_id": device_id,
"valid_until_ms": valid_until_ms,
+ "puppets_user_id": puppets_user_id,
},
desc="add_access_token_to_user",
)
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index dc0c4b5499..6b89db15c9 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1240,13 +1240,15 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
logger.error("store_room with room_id=%s failed: %s", room_id, e)
raise StoreError(500, "Problem creating room.")
- async def maybe_store_room_on_invite(self, room_id: str, room_version: RoomVersion):
+ async def maybe_store_room_on_outlier_membership(
+ self, room_id: str, room_version: RoomVersion
+ ):
"""
- When we receive an invite over federation, store the version of the room if we
- don't already know the room version.
+ When we receive an invite or any other event over federation that may relate to a room
+ we are not in, store the version of the room if we don't already know the room version.
"""
await self.db_pool.simple_upsert(
- desc="maybe_store_room_on_invite",
+ desc="maybe_store_room_on_outlier_membership",
table="rooms",
keyvalues={"room_id": room_id},
values={},
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 01d9dbb36f..dcdaf09682 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set
+from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Set, Tuple
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
@@ -350,6 +350,38 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
+ async def get_local_current_membership_for_user_in_room(
+ self, user_id: str, room_id: str
+ ) -> Tuple[Optional[str], Optional[str]]:
+ """Retrieve the current local membership state and event ID for a user in a room.
+
+ Args:
+ user_id: The ID of the user.
+ room_id: The ID of the room.
+
+ Returns:
+ A tuple of (membership_type, event_id). Both will be None if a
+ room_id/user_id pair is not found.
+ """
+ # Paranoia check.
+ if not self.hs.is_mine_id(user_id):
+ raise Exception(
+ "Cannot call 'get_local_current_membership_for_user_in_room' on "
+ "non-local user %s" % (user_id,),
+ )
+
+ results_dict = await self.db_pool.simple_select_one(
+ "local_current_membership",
+ {"room_id": room_id, "user_id": user_id},
+ ("membership", "event_id"),
+ allow_none=True,
+ desc="get_local_current_membership_for_user_in_room",
+ )
+ if not results_dict:
+ return None, None
+
+ return results_dict.get("membership"), results_dict.get("event_id")
+
@cached(max_entries=500000, iterable=True)
async def get_rooms_for_user_with_stream_ordering(
self, user_id: str
diff --git a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
index b64926e9c9..3275ae2b20 100644
--- a/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
+++ b/synapse/storage/databases/main/schema/delta/58/07add_method_to_thumbnail_constraint.sql.postgres
@@ -20,14 +20,14 @@
*/
-- add new index that includes method to local media
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('local_media_repository_thumbnails_method_idx', '{}');
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5807, 'local_media_repository_thumbnails_method_idx', '{}');
-- add new index that includes method to remote media
-INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
- ('remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (5807, 'remote_media_repository_thumbnails_method_idx', '{}', 'local_media_repository_thumbnails_method_idx');
-- drop old index
-INSERT INTO background_updates (update_name, progress_json, depends_on) VALUES
- ('media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+ (5807, 'media_repository_drop_index_wo_method', '{}', 'remote_media_repository_thumbnails_method_idx');
diff --git a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
index cade5dcca8..fd733adf13 100644
--- a/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
+++ b/synapse/storage/databases/main/schema/delta/58/12room_stats.sql
@@ -28,5 +28,5 @@
-- functionality as the old one. This effectively restarts the background job
-- from the beginning, without running it twice in a row, supporting both
-- upgrade usecases.
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('populate_stats_process_rooms_2', '{}');
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5812, 'populate_stats_process_rooms_2', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
index a2842687f1..e1a35be831 100644
--- a/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
+++ b/synapse/storage/databases/main/schema/delta/58/22users_have_local_media.sql
@@ -1,2 +1,2 @@
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('users_have_local_media', '{}');
\ No newline at end of file
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5822, 'users_have_local_media', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
index 61c558db77..75c3915a94 100644
--- a/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
+++ b/synapse/storage/databases/main/schema/delta/58/23e2e_cross_signing_keys_idx.sql
@@ -13,5 +13,5 @@
* limitations under the License.
*/
-INSERT INTO background_updates (update_name, progress_json) VALUES
- ('e2e_cross_signing_keys_idx', '{}');
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5823, 'e2e_cross_signing_keys_idx', '{}');
diff --git a/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql
new file mode 100644
index 0000000000..8a39d54aed
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/24drop_event_json_index.sql
@@ -0,0 +1,19 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- this index is essentially redundant. The only time it was ever used was when purging
+-- rooms - and Synapse 1.24 will change that.
+
+DROP INDEX IF EXISTS event_json_room_id;
diff --git a/synapse/types.py b/synapse/types.py
index 66bb5bac8d..3ab6bdbe06 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -317,14 +317,14 @@ mxid_localpart_allowed_characters = set(
)
-def contains_invalid_mxid_characters(localpart):
+def contains_invalid_mxid_characters(localpart: str) -> bool:
"""Check for characters not allowed in an mxid or groupid localpart
Args:
- localpart (basestring): the localpart to be checked
+ localpart: the localpart to be checked
Returns:
- bool: True if there are any naughty characters
+ True if there are any naughty characters
"""
return any(c not in mxid_localpart_allowed_characters for c in localpart)
|