summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2021-03-09 15:23:55 +0000
committerRichard van der Hoff <richard@matrix.org>2021-03-09 15:23:55 +0000
commit56c0c711c169548a2a4cf4e1948a76f7974ec4f8 (patch)
treeb8e625040829cea105d37556b11aa1598828e107 /synapse
parentMerge remote-tracking branch 'origin/release-v1.29.0' into matrix-org-hotfixes (diff)
parentLink to the List user's media admin API from media Admin API docs (#9571) (diff)
downloadsynapse-56c0c711c169548a2a4cf4e1948a76f7974ec4f8.tar.xz
Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/api/auth.py41
-rw-r--r--synapse/appservice/api.py2
-rw-r--r--synapse/config/_base.py38
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/config/logger.py5
-rw-r--r--synapse/config/oidc_config.py89
-rw-r--r--synapse/config/server.py3
-rw-r--r--synapse/federation/federation_server.py25
-rw-r--r--synapse/federation/sender/transaction_manager.py11
-rw-r--r--synapse/handlers/acme.py4
-rw-r--r--synapse/handlers/auth.py68
-rw-r--r--synapse/handlers/initial_sync.py2
-rw-r--r--synapse/handlers/oidc_handler.py166
-rw-r--r--synapse/handlers/pagination.py2
-rw-r--r--synapse/handlers/register.py35
-rw-r--r--synapse/handlers/room.py2
-rw-r--r--synapse/handlers/room_list.py4
-rw-r--r--synapse/handlers/sso.py3
-rw-r--r--synapse/handlers/sync.py2
-rw-r--r--synapse/http/client.py5
-rw-r--r--synapse/http/federation/matrix_federation_agent.py3
-rw-r--r--synapse/http/federation/well_known_resolver.py3
-rw-r--r--synapse/http/matrixfederationclient.py15
-rw-r--r--synapse/logging/context.py6
-rw-r--r--synapse/module_api/__init__.py31
-rw-r--r--synapse/replication/http/_base.py9
-rw-r--r--synapse/replication/tcp/redis.py2
-rw-r--r--synapse/rest/admin/purge_room_servlet.py15
-rw-r--r--synapse/rest/admin/server_notice_servlet.py23
-rw-r--r--synapse/rest/client/v1/login.py14
-rw-r--r--synapse/rest/media/v1/thumbnailer.py11
-rw-r--r--synapse/server.py5
-rw-r--r--synapse/types.py16
-rw-r--r--synapse/util/async_helpers.py26
-rw-r--r--synapse/util/caches/response_cache.py10
-rw-r--r--synapse/util/macaroons.py89
37 files changed, 587 insertions, 202 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py

index 2c24d4ae03..56ca888862 100644 --- a/synapse/__init__.py +++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try: except ImportError: pass -__version__ = "1.29.0rc1" +__version__ = "1.29.0" if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)): # We import here so that we don't have to install a bunch of deps when diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 89e62b0e36..968cf6f174 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py
@@ -39,6 +39,7 @@ from synapse.logging import opentracing as opentracing from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import StateMap, UserID from synapse.util.caches.lrucache import LruCache +from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry from synapse.util.metrics import Measure logger = logging.getLogger(__name__) @@ -408,7 +409,7 @@ class Auth: raise _InvalidMacaroonException() try: - user_id = self.get_user_id_from_macaroon(macaroon) + user_id = get_value_from_macaroon(macaroon, "user_id") guest = False for caveat in macaroon.caveats: @@ -416,7 +417,12 @@ class Auth: guest = True self.validate_macaroon(macaroon, rights, user_id=user_id) - except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError): + except ( + pymacaroons.exceptions.MacaroonException, + KeyError, + TypeError, + ValueError, + ): raise InvalidClientTokenError("Invalid macaroon passed.") if rights == "access": @@ -424,27 +430,6 @@ class Auth: return user_id, guest - def get_user_id_from_macaroon(self, macaroon): - """Retrieve the user_id given by the caveats on the macaroon. - - Does *not* validate the macaroon. - - Args: - macaroon (pymacaroons.Macaroon): The macaroon to validate - - Returns: - (str) user id - - Raises: - InvalidClientCredentialsError if there is no user_id caveat in the - macaroon - """ - user_prefix = "user_id = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(user_prefix): - return caveat.caveat_id[len(user_prefix) :] - raise InvalidClientTokenError("No user caveat in macaroon") - def validate_macaroon(self, macaroon, type_string, user_id): """ validate that a Macaroon is understood by and was signed by this server. @@ -465,21 +450,13 @@ class Auth: v.satisfy_exact("type = " + type_string) v.satisfy_exact("user_id = %s" % user_id) v.satisfy_exact("guest = true") - v.satisfy_general(self._verify_expiry) + satisfy_expiry(v, self.clock.time_msec) # access_tokens include a nonce for uniqueness: any value is acceptable v.satisfy_general(lambda c: c.startswith("nonce = ")) v.verify(macaroon, self._macaroon_secret_key) - def _verify_expiry(self, caveat): - prefix = "time < " - if not caveat.startswith(prefix): - return False - expiry = int(caveat[len(prefix) :]) - now = self.hs.get_clock().time_msec() - return now < expiry - def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService: token = self.get_access_token_from_request(request) service = self.store.get_app_service_by_token(token) diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 93c2aabcca..9d3bbe3b8b 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py
@@ -90,7 +90,7 @@ class ApplicationServiceApi(SimpleHttpClient): self.clock = hs.get_clock() self.protocol_meta_cache = ResponseCache( - hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS + hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS ) # type: ResponseCache[Tuple[str, str]] async def query_user(self, service, user_id): diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 4026966711..ba9cd63cf2 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py
@@ -212,9 +212,8 @@ class Config: @classmethod def read_file(cls, file_path, config_name): - cls.check_file(file_path, config_name) - with open(file_path) as file_stream: - return file_stream.read() + """Deprecated: call read_file directly""" + return read_file(file_path, (config_name,)) def read_template(self, filename: str) -> jinja2.Template: """Load a template file from disk. @@ -894,4 +893,35 @@ class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): return self._get_instance(key) -__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"] +def read_file(file_path: Any, config_path: Iterable[str]) -> str: + """Check the given file exists, and read it into a string + + If it does not, emit an error indicating the problem + + Args: + file_path: the file to be read + config_path: where in the configuration file_path came from, so that a useful + error can be emitted if it does not exist. + Returns: + content of the file. + Raises: + ConfigError if there is a problem reading the file. + """ + if not isinstance(file_path, str): + raise ConfigError("%r is not a string", config_path) + + try: + os.stat(file_path) + with open(file_path) as file_stream: + return file_stream.read() + except OSError as e: + raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e + + +__all__ = [ + "Config", + "RootConfig", + "ShardedWorkerHandlingConfig", + "RoutableShardedWorkerHandlingConfig", + "read_file", +] diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index db16c86f50..e896fd34e2 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi
@@ -152,3 +152,5 @@ class ShardedWorkerHandlingConfig: class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig): def get_instance(self, key: str) -> str: ... + +def read_file(file_path: Any, config_path: Iterable[str]) -> str: ... diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index e56cf846f5..999aecce5c 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py
@@ -21,8 +21,10 @@ import threading from string import Template import yaml +from zope.interface import implementer from twisted.logger import ( + ILogObserver, LogBeginner, STDLibLogObserver, eventAsText, @@ -227,7 +229,8 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> threadlocal = threading.local() - def _log(event): + @implementer(ILogObserver) + def _log(event: dict) -> None: if "log_text" in event: if event["log_text"].startswith("DNSDatagramProtocol starting on "): return diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index a27594befc..7f5e449eb2 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py
@@ -15,7 +15,7 @@ # limitations under the License. from collections import Counter -from typing import Iterable, Optional, Tuple, Type +from typing import Iterable, Mapping, Optional, Tuple, Type import attr @@ -25,7 +25,7 @@ from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module from synapse.util.stringutils import parse_and_validate_mxc_uri -from ._base import Config, ConfigError +from ._base import Config, ConfigError, read_file DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider" @@ -97,7 +97,26 @@ class OIDCConfig(Config): # # client_id: Required. oauth2 client id to use. # - # client_secret: Required. oauth2 client secret to use. + # client_secret: oauth2 client secret to use. May be omitted if + # client_secret_jwt_key is given, or if client_auth_method is 'none'. + # + # client_secret_jwt_key: Alternative to client_secret: details of a key used + # to create a JSON Web Token to be used as an OAuth2 client secret. If + # given, must be a dictionary with the following properties: + # + # key: a pem-encoded signing key. Must be a suitable key for the + # algorithm specified. Required unless 'key_file' is given. + # + # key_file: the path to file containing a pem-encoded signing key file. + # Required unless 'key' is given. + # + # jwt_header: a dictionary giving properties to include in the JWT + # header. Must include the key 'alg', giving the algorithm used to + # sign the JWT, such as "ES256", using the JWA identifiers in + # RFC7518. + # + # jwt_payload: an optional dictionary giving properties to include in + # the JWT payload. Normally this should include an 'iss' key. # # client_auth_method: auth method to use when exchanging the token. Valid # values are 'client_secret_basic' (default), 'client_secret_post' and @@ -240,7 +259,7 @@ class OIDCConfig(Config): # jsonschema definition of the configuration settings for an oidc identity provider OIDC_PROVIDER_CONFIG_SCHEMA = { "type": "object", - "required": ["issuer", "client_id", "client_secret"], + "required": ["issuer", "client_id"], "properties": { "idp_id": { "type": "string", @@ -262,6 +281,30 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "issuer": {"type": "string"}, "client_id": {"type": "string"}, "client_secret": {"type": "string"}, + "client_secret_jwt_key": { + "type": "object", + "required": ["jwt_header"], + "oneOf": [ + {"required": ["key"]}, + {"required": ["key_file"]}, + ], + "properties": { + "key": {"type": "string"}, + "key_file": {"type": "string"}, + "jwt_header": { + "type": "object", + "required": ["alg"], + "properties": { + "alg": {"type": "string"}, + }, + "additionalProperties": {"type": "string"}, + }, + "jwt_payload": { + "type": "object", + "additionalProperties": {"type": "string"}, + }, + }, + }, "client_auth_method": { "type": "string", # the following list is the same as the keys of @@ -404,6 +447,20 @@ def _parse_oidc_config_dict( "idp_icon must be a valid MXC URI", config_path + ("idp_icon",) ) from e + client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key") + client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey] + if client_secret_jwt_key_config is not None: + keyfile = client_secret_jwt_key_config.get("key_file") + if keyfile: + key = read_file(keyfile, config_path + ("client_secret_jwt_key",)) + else: + key = client_secret_jwt_key_config["key"] + client_secret_jwt_key = OidcProviderClientSecretJwtKey( + key=key, + jwt_header=client_secret_jwt_key_config["jwt_header"], + jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}), + ) + return OidcProviderConfig( idp_id=idp_id, idp_name=oidc_config.get("idp_name", "OIDC"), @@ -412,7 +469,8 @@ def _parse_oidc_config_dict( discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], - client_secret=oidc_config["client_secret"], + client_secret=oidc_config.get("client_secret"), + client_secret_jwt_key=client_secret_jwt_key, client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), scopes=oidc_config.get("scopes", ["openid"]), authorization_endpoint=oidc_config.get("authorization_endpoint"), @@ -428,6 +486,18 @@ def _parse_oidc_config_dict( @attr.s(slots=True, frozen=True) +class OidcProviderClientSecretJwtKey: + # a pem-encoded signing key + key = attr.ib(type=str) + + # properties to include in the JWT header + jwt_header = attr.ib(type=Mapping[str, str]) + + # properties to include in the JWT payload. + jwt_payload = attr.ib(type=Mapping[str, str]) + + +@attr.s(slots=True, frozen=True) class OidcProviderConfig: # a unique identifier for this identity provider. Used in the 'user_external_ids' # table, as well as the query/path parameter used in the login protocol. @@ -452,8 +522,13 @@ class OidcProviderConfig: # oauth2 client id to use client_id = attr.ib(type=str) - # oauth2 client secret to use - client_secret = attr.ib(type=str) + # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate + # a secret. + client_secret = attr.ib(type=Optional[str]) + + # key to use to construct a JWT to use as a client secret. May be `None` if + # `client_secret` is set. + client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey]) # auth method to use when exchanging the token. # Valid values are 'client_secret_basic', 'client_secret_post' and diff --git a/synapse/config/server.py b/synapse/config/server.py
index 2afca36e7d..5f8910b6e1 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py
@@ -841,8 +841,7 @@ class ServerConfig(Config): # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. Defaults to # 'false'. Note that profile data is also available via the federation - # API, so this setting is of limited value if federation is enabled on - # the server. + # API, unless allow_profile_lookup_over_federation is set to false. # #require_auth_for_profile_requests: true diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 93aa199119..f8e368f81b 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py
@@ -22,6 +22,7 @@ from typing import ( Awaitable, Callable, Dict, + Iterable, List, Optional, Tuple, @@ -90,16 +91,15 @@ pdu_process_time = Histogram( "Time taken to process an event", ) - -last_pdu_age_metric = Gauge( - "synapse_federation_last_received_pdu_age", - "The age (in seconds) of the last PDU successfully received from the given domain", +last_pdu_ts_metric = Gauge( + "synapse_federation_last_received_pdu_time", + "The timestamp of the last PDU which was successfully received from the given domain", labelnames=("server_name",), ) class FederationServer(FederationBase): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self.auth = hs.get_auth() @@ -119,7 +119,7 @@ class FederationServer(FederationBase): # We cache results for transaction with the same ID self._transaction_resp_cache = ResponseCache( - hs, "fed_txn_handler", timeout_ms=30000 + hs.get_clock(), "fed_txn_handler", timeout_ms=30000 ) # type: ResponseCache[Tuple[str, str]] self.transaction_actions = TransactionActions(self.store) @@ -129,10 +129,10 @@ class FederationServer(FederationBase): # We cache responses to state queries, as they take a while and often # come in waves. self._state_resp_cache = ResponseCache( - hs, "state_resp", timeout_ms=30000 + hs.get_clock(), "state_resp", timeout_ms=30000 ) # type: ResponseCache[Tuple[str, str]] self._state_ids_resp_cache = ResponseCache( - hs, "state_ids_resp", timeout_ms=30000 + hs.get_clock(), "state_ids_resp", timeout_ms=30000 ) # type: ResponseCache[Tuple[str, str]] self._federation_metrics_domains = ( @@ -361,7 +361,7 @@ class FederationServer(FederationBase): logger.error( "Failed to handle PDU %s", event_id, - exc_info=(f.type, f.value, f.getTracebackObject()), + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore ) await concurrently_execute( @@ -369,8 +369,7 @@ class FederationServer(FederationBase): ) if newest_pdu_ts and origin in self._federation_metrics_domains: - newest_pdu_age = self._clock.time_msec() - newest_pdu_ts - last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000) + last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000) return pdu_results @@ -455,7 +454,9 @@ class FederationServer(FederationBase): self, room_id: str, event_id: str ) -> Dict[str, list]: if event_id: - pdus = await self.handler.get_state_for_pdu(room_id, event_id) + pdus = await self.handler.get_state_for_pdu( + room_id, event_id + ) # type: Iterable[EventBase] else: pdus = (await self.state.get_current_state(room_id)).values() diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 763aff296c..2a9cd063c4 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py
@@ -36,9 +36,9 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -last_pdu_age_metric = Gauge( - "synapse_federation_last_sent_pdu_age", - "The age (in seconds) of the last PDU successfully sent to the given domain", +last_pdu_ts_metric = Gauge( + "synapse_federation_last_sent_pdu_time", + "The timestamp of the last PDU which was successfully sent to the given domain", labelnames=("server_name",), ) @@ -187,9 +187,8 @@ class TransactionManager: if success and pdus and destination in self._federation_metrics_domains: last_pdu = pdus[-1] - last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts - last_pdu_age_metric.labels(server_name=destination).set( - last_pdu_age / 1000 + last_pdu_ts_metric.labels(server_name=destination).set( + last_pdu.origin_server_ts / 1000 ) set_tag(tags.ERROR, not success) diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 5ecb2da1ac..132be238dd 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py
@@ -73,7 +73,9 @@ class AcmeHandler: "Listening for ACME requests on %s:%i", host, self.hs.config.acme_port ) try: - self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host) + self.reactor.listenTCP( + self.hs.config.acme_port, srv, backlog=50, interface=host + ) except twisted.internet.error.CannotListenError as e: check_bind_error(e, host, bind_addresses) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3978e41518..bec0c615d4 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.async_helpers import maybe_awaitable +from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.threepids import canonicalise_email @@ -170,6 +171,16 @@ class SsoLoginExtraAttributes: extra_attributes = attr.ib(type=JsonDict) +@attr.s(slots=True, frozen=True) +class LoginTokenAttributes: + """Data we store in a short-term login token""" + + user_id = attr.ib(type=str) + + # the SSO Identity Provider that the user authenticated with, to get this token + auth_provider_id = attr.ib(type=str) + + class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 @@ -1164,18 +1175,16 @@ class AuthHandler(BaseHandler): return None return user_id - async def validate_short_term_login_token_and_get_user_id(self, login_token: str): - auth_api = self.hs.get_auth() - user_id = None + async def validate_short_term_login_token( + self, login_token: str + ) -> LoginTokenAttributes: try: - macaroon = pymacaroons.Macaroon.deserialize(login_token) - user_id = auth_api.get_user_id_from_macaroon(macaroon) - auth_api.validate_macaroon(macaroon, "login", user_id) + res = self.macaroon_gen.verify_short_term_login_token(login_token) except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) - await self.auth.check_auth_blocking(user_id) - return user_id + await self.auth.check_auth_blocking(res.user_id) + return res async def delete_access_token(self, access_token: str): """Invalidate a single access token @@ -1397,6 +1406,7 @@ class AuthHandler(BaseHandler): async def complete_sso_login( self, registered_user_id: str, + auth_provider_id: str, request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, @@ -1406,6 +1416,9 @@ class AuthHandler(BaseHandler): Args: registered_user_id: The registered user ID to complete SSO login for. + auth_provider_id: The id of the SSO Identity provider that was used for + login. This will be stored in the login token for future tracking in + prometheus metrics. request: The request to complete. client_redirect_url: The URL to which to redirect the user at the end of the process. @@ -1427,6 +1440,7 @@ class AuthHandler(BaseHandler): self._complete_sso_login( registered_user_id, + auth_provider_id, request, client_redirect_url, extra_attributes, @@ -1437,6 +1451,7 @@ class AuthHandler(BaseHandler): def _complete_sso_login( self, registered_user_id: str, + auth_provider_id: str, request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, @@ -1463,7 +1478,7 @@ class AuthHandler(BaseHandler): # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( - registered_user_id + registered_user_id, auth_provider_id=auth_provider_id ) # Append the login token to the original redirect URL (i.e. with its query @@ -1569,15 +1584,48 @@ class MacaroonGenerator: return macaroon.serialize() def generate_short_term_login_token( - self, user_id: str, duration_in_ms: int = (2 * 60 * 1000) + self, + user_id: str, + auth_provider_id: str, + duration_in_ms: int = (2 * 60 * 1000), ) -> str: macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") now = self.hs.get_clock().time_msec() expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) + macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,)) return macaroon.serialize() + def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: + """Verify a short-term-login macaroon + + Checks that the given token is a valid, unexpired short-term-login token + minted by this server. + + Args: + token: the login token to verify + + Returns: + the user_id that this token is valid for + + Raises: + MacaroonVerificationFailedException if the verification failed + """ + macaroon = pymacaroons.Macaroon.deserialize(token) + user_id = get_value_from_macaroon(macaroon, "user_id") + auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") + + v = pymacaroons.Verifier() + v.satisfy_exact("gen = 1") + v.satisfy_exact("type = login") + v.satisfy_general(lambda c: c.startswith("user_id = ")) + v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) + satisfy_expiry(v, self.hs.get_clock().time_msec) + v.verify(macaroon, self.hs.config.key.macaroon_secret_key) + + return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id) + def generate_delete_pusher_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = delete_pusher") diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 71a5076672..13f8152283 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py
@@ -48,7 +48,7 @@ class InitialSyncHandler(BaseHandler): self.clock = hs.get_clock() self.validator = EventValidator() self.snapshot_cache = ResponseCache( - hs, "initial_sync_cache" + hs.get_clock(), "initial_sync_cache" ) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]] self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 07db1e31e4..825fadb76f 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py
@@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2020 Quentin Gliech +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,13 +15,13 @@ # limitations under the License. import inspect import logging -from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar +from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union from urllib.parse import urlencode import attr import pymacaroons from authlib.common.security import generate_token -from authlib.jose import JsonWebToken +from authlib.jose import JsonWebToken, jwt from authlib.oauth2.auth import ClientAuth from authlib.oauth2.rfc6749.parameters import prepare_grant_uri from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo @@ -35,13 +36,17 @@ from typing_extensions import TypedDict from twisted.web.client import readBody from synapse.config import ConfigError -from synapse.config.oidc_config import OidcProviderConfig +from synapse.config.oidc_config import ( + OidcProviderClientSecretJwtKey, + OidcProviderConfig, +) from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart -from synapse.util import json_decoder +from synapse.util import Clock, json_decoder from synapse.util.caches.cached_call import RetryOnExceptionCachedCall +from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry if TYPE_CHECKING: from synapse.server import HomeServer @@ -211,7 +216,7 @@ class OidcHandler: session_data = self._token_generator.verify_oidc_session_token( session, state ) - except (MacaroonDeserializationException, ValueError) as e: + except (MacaroonDeserializationException, KeyError) as e: logger.exception("Invalid session for OIDC callback") self._sso_handler.render_error(request, "invalid_session", str(e)) return @@ -275,9 +280,21 @@ class OidcProvider: self._scopes = provider.scopes self._user_profile_method = provider.user_profile_method + + client_secret = None # type: Union[None, str, JwtClientSecret] + if provider.client_secret: + client_secret = provider.client_secret + elif provider.client_secret_jwt_key: + client_secret = JwtClientSecret( + provider.client_secret_jwt_key, + provider.client_id, + provider.issuer, + hs.get_clock(), + ) + self._client_auth = ClientAuth( provider.client_id, - provider.client_secret, + client_secret, provider.client_auth_method, ) # type: ClientAuth self._client_auth_method = provider.client_auth_method @@ -745,7 +762,7 @@ class OidcProvider: idp_id=self.idp_id, nonce=nonce, client_redirect_url=client_redirect_url.decode(), - ui_auth_session_id=ui_auth_session_id, + ui_auth_session_id=ui_auth_session_id or "", ), ) @@ -976,6 +993,81 @@ class OidcProvider: return str(remote_user_id) +# number of seconds a newly-generated client secret should be valid for +CLIENT_SECRET_VALIDITY_SECONDS = 3600 + +# minimum remaining validity on a client secret before we should generate a new one +CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600 + + +class JwtClientSecret: + """A class which generates a new client secret on demand, based on a JWK + + This implementation is designed to comply with the requirements for Apple Sign in: + https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048 + + It looks like those requirements are based on https://tools.ietf.org/html/rfc7523, + but it's worth noting that we still put the generated secret in the "client_secret" + field (or rather, whereever client_auth_method puts it) rather than in a + client_assertion field in the body as that RFC seems to require. + """ + + def __init__( + self, + key: OidcProviderClientSecretJwtKey, + oauth_client_id: str, + oauth_issuer: str, + clock: Clock, + ): + self._key = key + self._oauth_client_id = oauth_client_id + self._oauth_issuer = oauth_issuer + self._clock = clock + self._cached_secret = b"" + self._cached_secret_replacement_time = 0 + + def __str__(self): + # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls + # encode_client_secret_basic, which calls "{}".format(secret), which ends up + # here. + return self._get_secret().decode("ascii") + + def __bytes__(self): + # if client_auth_method is client_secret_post, then ClientAuth.prepare calls + # encode_client_secret_post, which ends up here. + return self._get_secret() + + def _get_secret(self) -> bytes: + now = self._clock.time() + + # if we have enough validity on our existing secret, use it + if now < self._cached_secret_replacement_time: + return self._cached_secret + + issued_at = int(now) + expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS + + # we copy the configured header because jwt.encode modifies it. + header = dict(self._key.jwt_header) + + # see https://tools.ietf.org/html/rfc7523#section-3 + payload = { + "sub": self._oauth_client_id, + "aud": self._oauth_issuer, + "iat": issued_at, + "exp": expires_at, + **self._key.jwt_payload, + } + logger.info( + "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload + ) + self._cached_secret = jwt.encode(header, payload, self._key.key) + self._cached_secret_replacement_time = ( + expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS + ) + return self._cached_secret + + class OidcSessionTokenGenerator: """Methods for generating and checking OIDC Session cookies.""" @@ -1020,10 +1112,9 @@ class OidcSessionTokenGenerator: macaroon.add_first_party_caveat( "client_redirect_url = %s" % (session_data.client_redirect_url,) ) - if session_data.ui_auth_session_id: - macaroon.add_first_party_caveat( - "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,) - ) + macaroon.add_first_party_caveat( + "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,) + ) now = self._clock.time_msec() expiry = now + duration_in_ms macaroon.add_first_party_caveat("time < %d" % (expiry,)) @@ -1046,7 +1137,7 @@ class OidcSessionTokenGenerator: The data extracted from the session cookie Raises: - ValueError if an expected caveat is missing from the macaroon. + KeyError if an expected caveat is missing from the macaroon. """ macaroon = pymacaroons.Macaroon.deserialize(session) @@ -1057,26 +1148,16 @@ class OidcSessionTokenGenerator: v.satisfy_general(lambda c: c.startswith("nonce = ")) v.satisfy_general(lambda c: c.startswith("idp_id = ")) v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) - # Sometimes there's a UI auth session ID, it seems to be OK to attempt - # to always satisfy this. v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = ")) - v.satisfy_general(self._verify_expiry) + satisfy_expiry(v, self._clock.time_msec) v.verify(macaroon, self._macaroon_secret_key) # Extract the session data from the token. - nonce = self._get_value_from_macaroon(macaroon, "nonce") - idp_id = self._get_value_from_macaroon(macaroon, "idp_id") - client_redirect_url = self._get_value_from_macaroon( - macaroon, "client_redirect_url" - ) - try: - ui_auth_session_id = self._get_value_from_macaroon( - macaroon, "ui_auth_session_id" - ) # type: Optional[str] - except ValueError: - ui_auth_session_id = None - + nonce = get_value_from_macaroon(macaroon, "nonce") + idp_id = get_value_from_macaroon(macaroon, "idp_id") + client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url") + ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id") return OidcSessionData( nonce=nonce, idp_id=idp_id, @@ -1084,33 +1165,6 @@ class OidcSessionTokenGenerator: ui_auth_session_id=ui_auth_session_id, ) - def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str: - """Extracts a caveat value from a macaroon token. - - Args: - macaroon: the token - key: the key of the caveat to extract - - Returns: - The extracted value - - Raises: - ValueError: if the caveat was not in the macaroon - """ - prefix = key + " = " - for caveat in macaroon.caveats: - if caveat.caveat_id.startswith(prefix): - return caveat.caveat_id[len(prefix) :] - raise ValueError("No %s caveat in macaroon" % (key,)) - - def _verify_expiry(self, caveat: str) -> bool: - prefix = "time < " - if not caveat.startswith(prefix): - return False - expiry = int(caveat[len(prefix) :]) - now = self._clock.time_msec() - return now < expiry - @attr.s(frozen=True, slots=True) class OidcSessionData: @@ -1125,8 +1179,8 @@ class OidcSessionData: # The URL the client gave when it initiated the flow. ("" if this is a UI Auth) client_redirect_url = attr.ib(type=str) - # The session ID of the ongoing UI Auth (None if this is a login) - ui_auth_session_id = attr.ib(type=Optional[str], default=None) + # The session ID of the ongoing UI Auth ("" if this is a login) + ui_auth_session_id = attr.ib(type=str) UserAttributeDict = TypedDict( diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 059064a4eb..66dc886c81 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py
@@ -285,7 +285,7 @@ class PaginationHandler: except Exception: f = Failure() logger.error( - "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) + "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore ) self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED finally: diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 3cda89657e..b66f8756b8 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py
@@ -18,6 +18,8 @@ import logging from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple +from prometheus_client import Counter + from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError @@ -41,6 +43,19 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +registration_counter = Counter( + "synapse_user_registrations_total", + "Number of new users registered (since restart)", + ["guest", "shadow_banned", "auth_provider"], +) + +login_counter = Counter( + "synapse_user_logins_total", + "Number of user logins (since restart)", + ["guest", "auth_provider"], +) + + class RegistrationHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -156,6 +171,7 @@ class RegistrationHandler(BaseHandler): bind_emails: Iterable[str] = [], by_admin: bool = False, user_agent_ips: Optional[List[Tuple[str, str]]] = None, + auth_provider_id: Optional[str] = None, ) -> str: """Registers a new client on the server. @@ -181,8 +197,10 @@ class RegistrationHandler(BaseHandler): admin api, otherwise False. user_agent_ips: Tuples of IP addresses and user-agents used during the registration process. + auth_provider_id: The SSO IdP the user used, if any (just used for the + prometheus metrics). Returns: - The registere user_id. + The registered user_id. Raises: SynapseError if there was a problem registering. """ @@ -280,6 +298,12 @@ class RegistrationHandler(BaseHandler): # if user id is taken, just generate another fail_count += 1 + registration_counter.labels( + guest=make_guest, + shadow_banned=shadow_banned, + auth_provider=(auth_provider_id or ""), + ).inc() + if not self.hs.config.user_consent_at_registration: if not self.hs.config.auto_join_rooms_for_guests and make_guest: logger.info( @@ -638,6 +662,7 @@ class RegistrationHandler(BaseHandler): initial_display_name: Optional[str], is_guest: bool = False, is_appservice_ghost: bool = False, + auth_provider_id: Optional[str] = None, ) -> Tuple[str, str]: """Register a device for a user and generate an access token. @@ -648,7 +673,8 @@ class RegistrationHandler(BaseHandler): device_id: The device ID to check, or None to generate a new one. initial_display_name: An optional display name for the device. is_guest: Whether this is a guest account - + auth_provider_id: The SSO IdP the user used, if any (just used for the + prometheus metrics). Returns: Tuple of device ID and access token """ @@ -687,6 +713,11 @@ class RegistrationHandler(BaseHandler): is_appservice_ghost=is_appservice_ghost, ) + login_counter.labels( + guest=is_guest, + auth_provider=(auth_provider_id or ""), + ).inc() + return (registered_device_id, access_token) async def post_registration_actions( diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a488df10d6..4b3d0d72e3 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py
@@ -121,7 +121,7 @@ class RoomCreationHandler(BaseHandler): # succession, only process the first attempt and return its result to # subsequent requests self._upgrade_response_cache = ResponseCache( - hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS + hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS ) # type: ResponseCache[Tuple[str, str]] self._server_notices_mxid = hs.config.server_notices_mxid diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 70522e40fa..8c5b60e6be 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py
@@ -45,10 +45,10 @@ class RoomListHandler(BaseHandler): self.enable_room_list_search = hs.config.enable_room_list_search self.response_cache = ResponseCache( - hs, "room_list" + hs.get_clock(), "room_list" ) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]] self.remote_response_cache = ResponseCache( - hs, "remote_room_list", timeout_ms=30 * 1000 + hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000 ) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]] async def get_local_public_room_list( diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 80e28bdcbe..6ef459acff 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py
@@ -456,6 +456,7 @@ class SsoHandler: await self._auth_handler.complete_sso_login( user_id, + auth_provider_id, request, client_redirect_url, extra_login_attributes, @@ -605,6 +606,7 @@ class SsoHandler: default_display_name=attributes.display_name, bind_emails=attributes.emails, user_agent_ips=[(user_agent, ip_address)], + auth_provider_id=auth_provider_id, ) await self._store.record_user_external_id( @@ -886,6 +888,7 @@ class SsoHandler: await self._auth_handler.complete_sso_login( user_id, + session.auth_provider_id, request, session.client_redirect_url, session.extra_login_attributes, diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6c8e361402..a65299bd22 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py
@@ -245,7 +245,7 @@ class SyncHandler: self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() self.response_cache = ResponseCache( - hs, "sync", timeout_ms=SYNC_RESPONSE_CACHE_MS + hs.get_clock(), "sync", timeout_ms=SYNC_RESPONSE_CACHE_MS ) # type: ResponseCache[Tuple[Any, ...]] self.state = hs.get_state_handler() self.auth = hs.get_auth() diff --git a/synapse/http/client.py b/synapse/http/client.py
index 72901e3f95..af34d583ad 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -63,6 +63,7 @@ from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_u from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags +from synapse.types import ISynapseReactor from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred @@ -199,7 +200,7 @@ class _IPBlacklistingResolver: return r -@implementer(IReactorPluggableNameResolver) +@implementer(ISynapseReactor) class BlacklistingReactorWrapper: """ A Reactor wrapper which will prevent DNS resolution to blacklisted IP @@ -324,7 +325,7 @@ class SimpleHttpClient: # filters out blacklisted IP addresses, to prevent DNS rebinding. self.reactor = BlacklistingReactorWrapper( hs.get_reactor(), self._ip_whitelist, self._ip_blacklist - ) + ) # type: ISynapseReactor else: self.reactor = hs.get_reactor() diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index b07aa59c08..5935a125fd 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py
@@ -35,6 +35,7 @@ from synapse.http.client import BlacklistingAgentWrapper from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.types import ISynapseReactor from synapse.util import Clock logger = logging.getLogger(__name__) @@ -68,7 +69,7 @@ class MatrixFederationAgent: def __init__( self, - reactor: IReactorCore, + reactor: ISynapseReactor, tls_client_options_factory: Optional[FederationPolicyForHTTPS], user_agent: bytes, ip_blacklist: IPSet, diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 4def7d7633..ecd63e6596 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py
@@ -322,7 +322,8 @@ def _cache_period_from_headers( def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: cache_controls = {} - for hdr in headers.getRawHeaders(b"cache-control", []): + cache_control_headers = headers.getRawHeaders(b"cache-control") or [] + for hdr in cache_control_headers: for directive in hdr.split(b","): splits = [x.strip() for x in directive.split(b"=", 1)] k = splits[0].lower() diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 0f107714ea..5f01ebd3d4 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -59,7 +59,7 @@ from synapse.logging.opentracing import ( start_active_span, tags, ) -from synapse.types import JsonDict +from synapse.types import ISynapseReactor, JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -237,14 +237,14 @@ class MatrixFederationHttpClient: # addresses, to prevent DNS rebinding. self.reactor = BlacklistingReactorWrapper( hs.get_reactor(), None, hs.config.federation_ip_range_blacklist - ) + ) # type: ISynapseReactor user_agent = hs.version_string if hs.config.user_agent_suffix: user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix) user_agent = user_agent.encode("ascii") - self.agent = MatrixFederationAgent( + federation_agent = MatrixFederationAgent( self.reactor, tls_client_options_factory, user_agent, @@ -254,7 +254,7 @@ class MatrixFederationHttpClient: # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( - self.agent, + federation_agent, ip_blacklist=hs.config.federation_ip_range_blacklist, ) @@ -534,9 +534,10 @@ class MatrixFederationHttpClient: response.code, response_phrase, body ) - # Retry if the error is a 429 (Too Many Requests), - # otherwise just raise a standard HttpResponseException - if response.code == 429: + # Retry if the error is a 5xx or a 429 (Too Many + # Requests), otherwise just raise a standard + # `HttpResponseException` + if 500 <= response.code < 600 or response.code == 429: raise RequestSendFailed(exc, can_retry=True) from exc else: raise exc diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 78e27bfb00..1a7ea4fa96 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py
@@ -669,7 +669,7 @@ def preserve_fn(f): return g -def run_in_background(f, *args, **kwargs): +def run_in_background(f, *args, **kwargs) -> defer.Deferred: """Calls a function, ensuring that the current context is restored after return from the function, and that the sentinel context is set once the deferred returned by the function completes. @@ -697,8 +697,10 @@ def run_in_background(f, *args, **kwargs): if isinstance(res, types.CoroutineType): res = defer.ensureDeferred(res) + # At this point we should have a Deferred, if not then f was a synchronous + # function, wrap it in a Deferred for consistency. if not isinstance(res, defer.Deferred): - return res + return defer.succeed(res) if res.called and not res.paused: # The function should have maintained the logcontext, so we can diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index db2d400b7e..781e02fbbb 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py
@@ -203,11 +203,26 @@ class ModuleApi: ) def generate_short_term_login_token( - self, user_id: str, duration_in_ms: int = (2 * 60 * 1000) + self, + user_id: str, + duration_in_ms: int = (2 * 60 * 1000), + auth_provider_id: str = "", ) -> str: - """Generate a login token suitable for m.login.token authentication""" + """Generate a login token suitable for m.login.token authentication + + Args: + user_id: gives the ID of the user that the token is for + + duration_in_ms: the time that the token will be valid for + + auth_provider_id: the ID of the SSO IdP that the user used to authenticate + to get this token, if any. This is encoded in the token so that + /login can report stats on number of successful logins by IdP. + """ return self._hs.get_macaroon_generator().generate_short_term_login_token( - user_id, duration_in_ms + user_id, + auth_provider_id, + duration_in_ms, ) @defer.inlineCallbacks @@ -276,6 +291,7 @@ class ModuleApi: """ self._auth_handler._complete_sso_login( registered_user_id, + "<unknown>", request, client_redirect_url, ) @@ -286,6 +302,7 @@ class ModuleApi: request: SynapseRequest, client_redirect_url: str, new_user: bool = False, + auth_provider_id: str = "<unknown>", ): """Complete a SSO login by redirecting the user to a page to confirm whether they want their access token sent to `client_redirect_url`, or redirect them to that @@ -299,9 +316,15 @@ class ModuleApi: redirect them directly if whitelisted). new_user: set to true to use wording for the consent appropriate to a user who has just registered. + auth_provider_id: the ID of the SSO IdP which was used to log in. This + is used to track counts of sucessful logins by IdP. """ await self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url, new_user=new_user + registered_user_id, + auth_provider_id, + request, + client_redirect_url, + new_user=new_user, ) @defer.inlineCallbacks diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 8a3f113e76..b7aa0c280f 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py
@@ -18,7 +18,7 @@ import logging import re import urllib from inspect import signature -from typing import Dict, List, Tuple +from typing import TYPE_CHECKING, Dict, List, Tuple from prometheus_client import Counter, Gauge @@ -28,6 +28,9 @@ from synapse.logging.opentracing import inject_active_span_byte_dict, trace from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import random_string +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) _pending_outgoing_requests = Gauge( @@ -88,10 +91,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): CACHE = True RETRY_ON_TIMEOUT = True - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): if self.CACHE: self.response_cache = ResponseCache( - hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000 + hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000 ) # type: ResponseCache[str] # We reserve `instance_name` as a parameter to sending requests, so we diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 0e6155cf53..7560706b4b 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py
@@ -328,6 +328,6 @@ def lazyConnection( factory.continueTrying = reconnect reactor = hs.get_reactor() - reactor.connectTCP(host, port, factory, 30) + reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None) return factory.handler diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
index 8b7bb6d44e..49966ee3e0 100644 --- a/synapse/rest/admin/purge_room_servlet.py +++ b/synapse/rest/admin/purge_room_servlet.py
@@ -12,13 +12,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Tuple + from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) +from synapse.http.site import SynapseRequest from synapse.rest.admin import assert_requester_is_admin from synapse.rest.admin._base import admin_patterns +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer class PurgeRoomServlet(RestServlet): @@ -36,16 +43,12 @@ class PurgeRoomServlet(RestServlet): PATTERNS = admin_patterns("/purge_room$") - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.pagination_handler = hs.get_pagination_handler() - async def on_POST(self, request): + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request) diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 375d055445..f495666f4a 100644 --- a/synapse/rest/admin/server_notice_servlet.py +++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,17 +12,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import TYPE_CHECKING, Optional, Tuple + from synapse.api.constants import EventTypes from synapse.api.errors import SynapseError +from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, assert_params_in_dict, parse_json_object_from_request, ) +from synapse.http.site import SynapseRequest from synapse.rest.admin import assert_requester_is_admin from synapse.rest.admin._base import admin_patterns from synapse.rest.client.transactions import HttpTransactionCache -from synapse.types import UserID +from synapse.types import JsonDict, UserID + +if TYPE_CHECKING: + from synapse.server import HomeServer class SendServerNoticeServlet(RestServlet): @@ -44,17 +51,13 @@ class SendServerNoticeServlet(RestServlet): } """ - def __init__(self, hs): - """ - Args: - hs (synapse.server.HomeServer): server - """ + def __init__(self, hs: "HomeServer"): self.hs = hs self.auth = hs.get_auth() self.txns = HttpTransactionCache(hs) self.snm = hs.get_server_notices_manager() - def register(self, json_resource): + def register(self, json_resource: HttpServer): PATTERN = "/send_server_notice" json_resource.register_paths( "POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__ @@ -66,7 +69,9 @@ class SendServerNoticeServlet(RestServlet): self.__class__.__name__, ) - async def on_POST(self, request, txn_id=None): + async def on_POST( + self, request: SynapseRequest, txn_id: Optional[str] = None + ) -> Tuple[int, JsonDict]: await assert_requester_is_admin(self.auth, request) body = parse_json_object_from_request(request) assert_params_in_dict(body, ("user_id", "content")) @@ -90,7 +95,7 @@ class SendServerNoticeServlet(RestServlet): return 200, {"event_id": event.event_id} - def on_PUT(self, request, txn_id): + def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]: return self.txns.fetch_or_execute_request( request, self.on_POST, request, txn_id ) diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 925edfc402..34bc1bd49b 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py
@@ -219,6 +219,7 @@ class LoginRestServlet(RestServlet): callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, create_non_existent_users: bool = False, ratelimit: bool = True, + auth_provider_id: Optional[str] = None, ) -> Dict[str, str]: """Called when we've successfully authed the user and now need to actually login them in (e.g. create devices). This gets called on @@ -234,6 +235,8 @@ class LoginRestServlet(RestServlet): create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. ratelimit: Whether to ratelimit the login request. + auth_provider_id: The SSO IdP the user used, if any (just used for the + prometheus metrics). Returns: result: Dictionary of account information after successful login. @@ -256,7 +259,7 @@ class LoginRestServlet(RestServlet): device_id = login_submission.get("device_id") initial_display_name = login_submission.get("initial_device_display_name") device_id, access_token = await self.registration_handler.register_device( - user_id, device_id, initial_display_name + user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id ) result = { @@ -283,12 +286,13 @@ class LoginRestServlet(RestServlet): """ token = login_submission["token"] auth_handler = self.auth_handler - user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( - token - ) + res = await auth_handler.validate_short_term_login_token(token) return await self._complete_login( - user_id, login_submission, self.auth_handler._sso_login_callback + res.user_id, + login_submission, + self.auth_handler._sso_login_callback, + auth_provider_id=res.auth_provider_id, ) async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 07903e4017..988f52c78f 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py
@@ -96,9 +96,14 @@ class Thumbnailer: def _resize(self, width: int, height: int) -> Image: # 1-bit or 8-bit color palette images need converting to RGB # otherwise they will be scaled using nearest neighbour which - # looks awful - if self.image.mode in ["1", "P"]: - self.image = self.image.convert("RGB") + # looks awful. + # + # If the image has transparency, use RGBA instead. + if self.image.mode in ["1", "L", "P"]: + mode = "RGB" + if self.image.info.get("transparency", None) is not None: + mode = "RGBA" + self.image = self.image.convert(mode) return self.image.resize((width, height), Image.ANTIALIAS) def scale(self, width: int, height: int, output_type: str) -> BytesIO: diff --git a/synapse/server.py b/synapse/server.py
index afd7cd72e7..369cc88026 100644 --- a/synapse/server.py +++ b/synapse/server.py
@@ -36,7 +36,6 @@ from typing import ( cast, ) -import twisted.internet.base import twisted.internet.tcp from twisted.internet import defer from twisted.mail.smtp import sendmail @@ -130,7 +129,7 @@ from synapse.server_notices.worker_server_notices_sender import ( from synapse.state import StateHandler, StateResolutionHandler from synapse.storage import Databases, DataStore, Storage from synapse.streams.events import EventSources -from synapse.types import DomainSpecificString +from synapse.types import DomainSpecificString, ISynapseReactor from synapse.util import Clock from synapse.util.distributor import Distributor from synapse.util.ratelimitutils import FederationRateLimiter @@ -291,7 +290,7 @@ class HomeServer(metaclass=abc.ABCMeta): for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP: getattr(self, "get_" + i + "_handler")() - def get_reactor(self) -> twisted.internet.base.ReactorBase: + def get_reactor(self) -> ISynapseReactor: """ Fetch the Twisted reactor in use by this HomeServer. """ diff --git a/synapse/types.py b/synapse/types.py
index 721343f0b5..0216d213c7 100644 --- a/synapse/types.py +++ b/synapse/types.py
@@ -35,6 +35,14 @@ from typing import ( import attr from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 +from zope.interface import Interface + +from twisted.internet.interfaces import ( + IReactorCore, + IReactorPluggableNameResolver, + IReactorTCP, + IReactorTime, +) from synapse.api.errors import Codes, SynapseError from synapse.util.stringutils import parse_and_validate_server_name @@ -67,6 +75,14 @@ MutableStateMap = MutableMapping[StateKey, T] JsonDict = Dict[str, Any] +# Note that this seems to require inheriting *directly* from Interface in order +# for mypy-zope to realize it is an interface. +class ISynapseReactor( + IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface +): + """The interfaces necessary for Synapse to function.""" + + class Requester( namedtuple( "Requester", diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 719e35b78d..f33c115844 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py
@@ -76,11 +76,16 @@ class ObservableDeferred: def callback(r): object.__setattr__(self, "_result", (True, r)) while self._observers: + observer = self._observers.pop() try: - # TODO: Handle errors here. - self._observers.pop().callback(r) - except Exception: - pass + observer.callback(r) + except Exception as e: + logger.exception( + "%r threw an exception on .callback(%r), ignoring...", + observer, + r, + exc_info=e, + ) return r def errback(f): @@ -90,11 +95,16 @@ class ObservableDeferred: # traces when we `await` on one of the observer deferreds. f.value.__failure__ = f + observer = self._observers.pop() try: - # TODO: Handle errors here. - self._observers.pop().errback(f) - except Exception: - pass + observer.errback(f) + except Exception as e: + logger.exception( + "%r threw an exception on .errback(%r), ignoring...", + observer, + f, + exc_info=e, + ) if consumeErrors: return None diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 32228f42ee..46ea8e0964 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py
@@ -13,17 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar +from typing import Any, Callable, Dict, Generic, Optional, TypeVar from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background +from synapse.util import Clock from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches import register_cache -if TYPE_CHECKING: - from synapse.app.homeserver import HomeServer - logger = logging.getLogger(__name__) T = TypeVar("T") @@ -37,11 +35,11 @@ class ResponseCache(Generic[T]): used rather than trying to compute a new response. """ - def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0): + def __init__(self, clock: Clock, name: str, timeout_ms: float = 0): # Requests that haven't finished yet. self.pending_result_cache = {} # type: Dict[T, ObservableDeferred] - self.clock = hs.get_clock() + self.clock = clock self.timeout_sec = timeout_ms / 1000.0 self._name = name diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py new file mode 100644
index 0000000000..12cdd53327 --- /dev/null +++ b/synapse/util/macaroons.py
@@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for manipulating macaroons""" + +from typing import Callable, Optional + +import pymacaroons +from pymacaroons.exceptions import MacaroonVerificationFailedException + + +def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: + """Extracts a caveat value from a macaroon token. + + Checks that there is exactly one caveat of the form "key = <val>" in the macaroon, + and returns the extracted value. + + Args: + macaroon: the token + key: the key of the caveat to extract + + Returns: + The extracted value + + Raises: + MacaroonVerificationFailedException: if there are conflicting values for the + caveat in the macaroon, or if the caveat was not found in the macaroon. + """ + prefix = key + " = " + result = None # type: Optional[str] + for caveat in macaroon.caveats: + if not caveat.caveat_id.startswith(prefix): + continue + + val = caveat.caveat_id[len(prefix) :] + + if result is None: + # first time we found this caveat: record the value + result = val + elif val != result: + # on subsequent occurrences, raise if the value is different. + raise MacaroonVerificationFailedException( + "Conflicting values for caveat " + key + ) + + if result is not None: + return result + + # If the caveat is not there, we raise a MacaroonVerificationFailedException. + # Note that it is insecure to generate a macaroon without all the caveats you + # might need (because there is nothing stopping people from adding extra caveats), + # so if the caveat isn't there, something odd must be going on. + raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,)) + + +def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None: + """Make a macaroon verifier which accepts 'time' caveats + + Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to + the given macaroon verifier. + + Args: + v: the macaroon verifier + get_time_ms: a callable which will return the timestamp after which the caveat + should be considered expired. Normally the current time. + """ + + def verify_expiry_caveat(caveat: str): + time_msec = get_time_ms() + prefix = "time < " + if not caveat.startswith(prefix): + return False + expiry = int(caveat[len(prefix) :]) + return time_msec < expiry + + v.satisfy_general(verify_expiry_caveat)