diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index fb0a615930..ba9cd63cf2 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -20,7 +20,6 @@ import errno
import os
from collections import OrderedDict
from hashlib import sha256
-from io import open as io_open
from textwrap import dedent
from typing import Any, Iterable, List, MutableMapping, Optional, Union
@@ -213,9 +212,8 @@ class Config:
@classmethod
def read_file(cls, file_path, config_name):
- cls.check_file(file_path, config_name)
- with io_open(file_path, encoding="utf-8") as file_stream:
- return file_stream.read()
+ """Deprecated: call read_file directly"""
+ return read_file(file_path, (config_name,))
def read_template(self, filename: str) -> jinja2.Template:
"""Load a template file from disk.
@@ -895,4 +893,35 @@ class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
return self._get_instance(key)
-__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
+def read_file(file_path: Any, config_path: Iterable[str]) -> str:
+ """Check the given file exists, and read it into a string
+
+ If it does not, emit an error indicating the problem
+
+ Args:
+ file_path: the file to be read
+ config_path: where in the configuration file_path came from, so that a useful
+ error can be emitted if it does not exist.
+ Returns:
+ content of the file.
+ Raises:
+ ConfigError if there is a problem reading the file.
+ """
+ if not isinstance(file_path, str):
+ raise ConfigError("%r is not a string", config_path)
+
+ try:
+ os.stat(file_path)
+ with open(file_path, encoding="utf-8") as file_stream:
+ return file_stream.read()
+ except OSError as e:
+ raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
+
+
+__all__ = [
+ "Config",
+ "RootConfig",
+ "ShardedWorkerHandlingConfig",
+ "RoutableShardedWorkerHandlingConfig",
+ "read_file",
+]
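
As a point of reference for callers migrating off the deprecated classmethod, here is a minimal sketch (with a hypothetical config section and option name) of how a config class would use the new module-level helper:

```python
from synapse.config._base import Config, read_file


class ExampleKeyConfig(Config):
    # hypothetical config section, for illustration only
    section = "examplekey"

    def read_config(self, config, **kwargs):
        key_path = config.get("example_signing_key_path")
        # read_file raises ConfigError (annotated with the given config path)
        # if the value is not a string or the file cannot be read.
        self.example_signing_key = read_file(
            key_path, ("example_signing_key_path",)
        )
```
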
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 5e1c9147a8..ddec356a07 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -154,3 +154,5 @@ class ShardedWorkerHandlingConfig:
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
def get_instance(self, key: str) -> str: ...
+
+def read_file(file_path: Any, config_path: Iterable[str]) -> str: ...
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index a27594befc..7f5e449eb2 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -15,7 +15,7 @@
# limitations under the License.
from collections import Counter
-from typing import Iterable, Optional, Tuple, Type
+from typing import Iterable, Mapping, Optional, Tuple, Type
import attr
@@ -25,7 +25,7 @@ from synapse.types import Collection, JsonDict
from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_mxc_uri
-from ._base import Config, ConfigError
+from ._base import Config, ConfigError, read_file
DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
@@ -97,7 +97,26 @@ class OIDCConfig(Config):
#
# client_id: Required. oauth2 client id to use.
#
- # client_secret: Required. oauth2 client secret to use.
+ # client_secret: oauth2 client secret to use. May be omitted if
+ # client_secret_jwt_key is given, or if client_auth_method is 'none'.
+ #
+ # client_secret_jwt_key: Alternative to client_secret: details of a key used
+ # to create a JSON Web Token to be used as an OAuth2 client secret. If
+ # given, must be a dictionary with the following properties:
+ #
+ # key: a pem-encoded signing key. Must be a suitable key for the
+ # algorithm specified. Required unless 'key_file' is given.
+ #
+ # key_file: the path to a file containing a pem-encoded signing key.
+ # Required unless 'key' is given.
+ #
+ # jwt_header: a dictionary giving properties to include in the JWT
+ # header. Must include the key 'alg', giving the algorithm used to
+ # sign the JWT, such as "ES256", using the JWA identifiers in
+ # RFC7518.
+ #
+ # jwt_payload: an optional dictionary giving properties to include in
+ # the JWT payload. Normally this should include an 'iss' key.
#
# client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and
@@ -240,7 +259,7 @@ class OIDCConfig(Config):
# jsonschema definition of the configuration settings for an oidc identity provider
OIDC_PROVIDER_CONFIG_SCHEMA = {
"type": "object",
- "required": ["issuer", "client_id", "client_secret"],
+ "required": ["issuer", "client_id"],
"properties": {
"idp_id": {
"type": "string",
@@ -262,6 +281,30 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"issuer": {"type": "string"},
"client_id": {"type": "string"},
"client_secret": {"type": "string"},
+ "client_secret_jwt_key": {
+ "type": "object",
+ "required": ["jwt_header"],
+ "oneOf": [
+ {"required": ["key"]},
+ {"required": ["key_file"]},
+ ],
+ "properties": {
+ "key": {"type": "string"},
+ "key_file": {"type": "string"},
+ "jwt_header": {
+ "type": "object",
+ "required": ["alg"],
+ "properties": {
+ "alg": {"type": "string"},
+ },
+ "additionalProperties": {"type": "string"},
+ },
+ "jwt_payload": {
+ "type": "object",
+ "additionalProperties": {"type": "string"},
+ },
+ },
+ },
"client_auth_method": {
"type": "string",
# the following list is the same as the keys of
@@ -404,6 +447,20 @@ def _parse_oidc_config_dict(
"idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
) from e
+ client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
+ client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey]
+ if client_secret_jwt_key_config is not None:
+ keyfile = client_secret_jwt_key_config.get("key_file")
+ if keyfile:
+ key = read_file(keyfile, config_path + ("client_secret_jwt_key",))
+ else:
+ key = client_secret_jwt_key_config["key"]
+ client_secret_jwt_key = OidcProviderClientSecretJwtKey(
+ key=key,
+ jwt_header=client_secret_jwt_key_config["jwt_header"],
+ jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}),
+ )
+
return OidcProviderConfig(
idp_id=idp_id,
idp_name=oidc_config.get("idp_name", "OIDC"),
@@ -412,7 +469,8 @@ def _parse_oidc_config_dict(
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
- client_secret=oidc_config["client_secret"],
+ client_secret=oidc_config.get("client_secret"),
+ client_secret_jwt_key=client_secret_jwt_key,
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
scopes=oidc_config.get("scopes", ["openid"]),
authorization_endpoint=oidc_config.get("authorization_endpoint"),
@@ -428,6 +486,18 @@ def _parse_oidc_config_dict(
@attr.s(slots=True, frozen=True)
+class OidcProviderClientSecretJwtKey:
+ # a pem-encoded signing key
+ key = attr.ib(type=str)
+
+ # properties to include in the JWT header
+ jwt_header = attr.ib(type=Mapping[str, str])
+
+ # properties to include in the JWT payload.
+ jwt_payload = attr.ib(type=Mapping[str, str])
+
+
+@attr.s(slots=True, frozen=True)
class OidcProviderConfig:
# a unique identifier for this identity provider. Used in the 'user_external_ids'
# table, as well as the query/path parameter used in the login protocol.
@@ -452,8 +522,13 @@ class OidcProviderConfig:
# oauth2 client id to use
client_id = attr.ib(type=str)
- # oauth2 client secret to use
- client_secret = attr.ib(type=str)
+ # oauth2 client secret to use. If `None`, use client_secret_jwt_key to generate
+ # a secret.
+ client_secret = attr.ib(type=Optional[str])
+
+ # key to use to construct a JWT to use as a client secret. May be `None` if
+ # `client_secret` is set.
+ client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and
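
For reviewers, here is one provider entry (expressed as the dict `_parse_oidc_config_dict` receives after YAML parsing, with hypothetical Apple-style values) that the relaxed schema should now accept: `client_secret` is absent and `client_secret_jwt_key` is supplied instead, with the key material loaded via `read_file` at parse time:

```python
# Hypothetical provider entry; all identifiers and paths are made up.
provider = {
    "idp_id": "apple",
    "idp_name": "Apple",
    "issuer": "https://appleid.apple.com",
    "client_id": "org.example.app",
    "client_auth_method": "client_secret_post",
    "client_secret_jwt_key": {
        # the schema's "oneOf" means exactly one of "key" / "key_file"
        "key_file": "/path/to/signing_key.p8",
        "jwt_header": {"alg": "ES256", "kid": "KEYID12345"},
        "jwt_payload": {"iss": "TEAMID12345"},
    },
}
```
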
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index b4a74390cc..825fadb76f 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,13 +15,13 @@
# limitations under the License.
import inspect
import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
from urllib.parse import urlencode
import attr
import pymacaroons
from authlib.common.security import generate_token
-from authlib.jose import JsonWebToken
+from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
@@ -35,12 +36,15 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
-from synapse.config.oidc_config import OidcProviderConfig
+from synapse.config.oidc_config import (
+ OidcProviderClientSecretJwtKey,
+ OidcProviderConfig,
+)
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
-from synapse.util import json_decoder
+from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
@@ -276,9 +280,21 @@ class OidcProvider:
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
+
+ client_secret = None # type: Union[None, str, JwtClientSecret]
+ if provider.client_secret:
+ client_secret = provider.client_secret
+ elif provider.client_secret_jwt_key:
+ client_secret = JwtClientSecret(
+ provider.client_secret_jwt_key,
+ provider.client_id,
+ provider.issuer,
+ hs.get_clock(),
+ )
+
self._client_auth = ClientAuth(
provider.client_id,
- provider.client_secret,
+ client_secret,
provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = provider.client_auth_method
@@ -977,6 +993,81 @@ class OidcProvider:
return str(remote_user_id)
+# number of seconds a newly-generated client secret should be valid for
+CLIENT_SECRET_VALIDITY_SECONDS = 3600
+
+# minimum remaining validity on a client secret before we should generate a new one
+CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600
+
+
+class JwtClientSecret:
+ """A class which generates a new client secret on demand, based on a JWK
+
+ This implementation is designed to comply with the requirements for Apple Sign in:
+ https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048
+
+ It looks like those requirements are based on https://tools.ietf.org/html/rfc7523,
+ but it's worth noting that we still put the generated secret in the "client_secret"
+ field (or rather, wherever client_auth_method puts it) rather than in a
+ client_assertion field in the body as that RFC seems to require.
+ """
+
+ def __init__(
+ self,
+ key: OidcProviderClientSecretJwtKey,
+ oauth_client_id: str,
+ oauth_issuer: str,
+ clock: Clock,
+ ):
+ self._key = key
+ self._oauth_client_id = oauth_client_id
+ self._oauth_issuer = oauth_issuer
+ self._clock = clock
+ self._cached_secret = b""
+ self._cached_secret_replacement_time = 0
+
+ def __str__(self):
+ # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
+ # encode_client_secret_basic, which calls "{}".format(secret), which ends up
+ # here.
+ return self._get_secret().decode("ascii")
+
+ def __bytes__(self):
+ # if client_auth_method is client_secret_post, then ClientAuth.prepare calls
+ # encode_client_secret_post, which ends up here.
+ return self._get_secret()
+
+ def _get_secret(self) -> bytes:
+ now = self._clock.time()
+
+ # if we have enough validity on our existing secret, use it
+ if now < self._cached_secret_replacement_time:
+ return self._cached_secret
+
+ issued_at = int(now)
+ expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS
+
+ # we copy the configured header because jwt.encode modifies it.
+ header = dict(self._key.jwt_header)
+
+ # see https://tools.ietf.org/html/rfc7523#section-3
+ payload = {
+ "sub": self._oauth_client_id,
+ "aud": self._oauth_issuer,
+ "iat": issued_at,
+ "exp": expires_at,
+ **self._key.jwt_payload,
+ }
+ logger.info(
+ "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
+ )
+ self._cached_secret = jwt.encode(header, payload, self._key.key)
+ self._cached_secret_replacement_time = (
+ expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
+ )
+ return self._cached_secret
+
+
class OidcSessionTokenGenerator:
"""Methods for generating and checking OIDC Session cookies."""
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index da6866addf..5f01ebd3d4 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -534,9 +534,10 @@ class MatrixFederationHttpClient:
response.code, response_phrase, body
)
- # Retry if the error is a 429 (Too Many Requests),
- # otherwise just raise a standard HttpResponseException
- if response.code == 429:
+ # Retry if the error is a 5xx or a 429 (Too Many
+ # Requests), otherwise just raise a standard
+ # `HttpResponseException`
+ if 500 <= response.code < 600 or response.code == 429:
raise RequestSendFailed(exc, can_retry=True) from exc
else:
raise exc
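
The new retry condition can be read as a small predicate; a standalone restatement for clarity:

```python
def should_retry(code: int) -> bool:
    # Retry on any 5xx (hopefully-transient server-side failure) and on 429
    # (rate limiting); other 4xx responses indicate a problem with the request
    # itself, so retrying would not help.
    return 500 <= code < 600 or code == 429
```
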
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index cb6b1f8a0c..78367ea58d 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -135,6 +135,11 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
self._chain_cover_index,
)
+ self.db_pool.updates.register_background_update_handler(
+ "purged_chain_cover",
+ self._purged_chain_cover_index,
+ )
+
async def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
max_stream_id = progress["max_stream_id_exclusive"]
@@ -932,3 +937,77 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
processed_count=count,
finished_room_map=finished_rooms,
)
+
+ async def _purged_chain_cover_index(self, progress: dict, batch_size: int) -> int:
+ """
+ A background update that iterates over the chain cover and deletes the
+ chain cover for events that have been purged.
+
+ This may be because a room was fully purged or because a retention policy removed the events.
+ """
+ current_event_id = progress.get("current_event_id", "")
+
+ def purged_chain_cover_txn(txn) -> int:
+ # The event ID from events will be null if the chain ID / sequence
+ # number points to a purged event.
+ sql = """
+ SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL
+ FROM event_auth_chains
+ LEFT JOIN events AS e USING (event_id)
+ WHERE event_id > ? ORDER BY event_auth_chains.event_id ASC LIMIT ?
+ """
+ txn.execute(sql, (current_event_id, batch_size))
+
+ rows = txn.fetchall()
+ if not rows:
+ return 0
+
+ # The event IDs and chain IDs / sequence numbers where the event has
+ # been purged.
+ unreferenced_event_ids = []
+ unreferenced_chain_id_tuples = []
+ event_id = ""
+ for event_id, chain_id, sequence_number, has_event in rows:
+ if not has_event:
+ unreferenced_event_ids.append((event_id,))
+ unreferenced_chain_id_tuples.append((chain_id, sequence_number))
+
+ # Delete the unreferenced auth chains from event_auth_chain_links and
+ # event_auth_chains.
+ txn.executemany(
+ """
+ DELETE FROM event_auth_chains WHERE event_id = ?
+ """,
+ unreferenced_event_ids,
+ )
+ # We should also delete matching target_*, but there is no index on
+ # target_chain_id. Hopefully any purged events are due to a room
+ # being fully purged and they will be removed from the origin_*
+ # searches.
+ txn.executemany(
+ """
+ DELETE FROM event_auth_chain_links WHERE
+ origin_chain_id = ? AND origin_sequence_number = ?
+ """,
+ unreferenced_chain_id_tuples,
+ )
+
+ progress = {
+ "current_event_id": event_id,
+ }
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn, "purged_chain_cover", progress
+ )
+
+ return len(rows)
+
+ result = await self.db_pool.runInteraction(
+ "_purged_chain_cover_index",
+ purged_chain_cover_txn,
+ )
+
+ if not result:
+ await self.db_pool.updates._end_background_update("purged_chain_cover")
+
+ return result
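
The purge detection hinges on the LEFT JOIN against `events`: rows where `e.event_id` comes back NULL point at chain-cover entries for purged events. A self-contained sqlite3 illustration of that pattern (toy schema, not Synapse's real tables):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE events (event_id TEXT PRIMARY KEY);
    CREATE TABLE event_auth_chains (event_id TEXT, chain_id INT, sequence_number INT);

    INSERT INTO events VALUES ('$kept');
    -- '$purged' still has a chain cover row, but its event has been deleted
    INSERT INTO event_auth_chains VALUES ('$kept', 1, 1), ('$purged', 2, 1);
    """
)

rows = conn.execute(
    """
    SELECT event_id, chain_id, sequence_number, e.event_id IS NOT NULL
    FROM event_auth_chains
    LEFT JOIN events AS e USING (event_id)
    ORDER BY event_auth_chains.event_id
    """
).fetchall()

# [('$kept', 1, 1, 1), ('$purged', 2, 1, 0)] -- the final column is the
# "has_event" flag the background update uses to decide which rows to delete.
print(rows)
```
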
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index 0836e4af49..41f4fe7f95 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -331,13 +331,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
txn.executemany(
"""
DELETE FROM event_auth_chain_links WHERE
- (origin_chain_id = ? AND origin_sequence_number = ?) OR
- (target_chain_id = ? AND target_sequence_number = ?)
+ origin_chain_id = ? AND origin_sequence_number = ?
""",
- (
- (chain_id, seq_num, chain_id, seq_num)
- for (chain_id, seq_num) in referenced_chain_id_tuples
- ),
+ referenced_chain_id_tuples,
)
# Now we delete tables which lack an index on room_id but have one on event_id
diff --git a/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql
new file mode 100644
index 0000000000..87cb1f3cfd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/10delete_purged_chain_cover.sql
@@ -0,0 +1,17 @@
+/* Copyright 2021 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5910, 'purged_chain_cover', '{}');