diff --git a/synapse/__init__.py b/synapse/__init__.py
index 2c24d4ae03..56ca888862 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.29.0rc1"
+__version__ = "1.29.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 89e62b0e36..968cf6f174 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -39,6 +39,7 @@ from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -408,7 +409,7 @@ class Auth:
raise _InvalidMacaroonException()
try:
- user_id = self.get_user_id_from_macaroon(macaroon)
+ user_id = get_value_from_macaroon(macaroon, "user_id")
guest = False
for caveat in macaroon.caveats:
@@ -416,7 +417,12 @@ class Auth:
guest = True
self.validate_macaroon(macaroon, rights, user_id=user_id)
- except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+ except (
+ pymacaroons.exceptions.MacaroonException,
+ KeyError,
+ TypeError,
+ ValueError,
+ ):
raise InvalidClientTokenError("Invalid macaroon passed.")
if rights == "access":
@@ -424,27 +430,6 @@ class Auth:
return user_id, guest
- def get_user_id_from_macaroon(self, macaroon):
- """Retrieve the user_id given by the caveats on the macaroon.
-
- Does *not* validate the macaroon.
-
- Args:
- macaroon (pymacaroons.Macaroon): The macaroon to validate
-
- Returns:
- (str) user id
-
- Raises:
- InvalidClientCredentialsError if there is no user_id caveat in the
- macaroon
- """
- user_prefix = "user_id = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(user_prefix):
- return caveat.caveat_id[len(user_prefix) :]
- raise InvalidClientTokenError("No user caveat in macaroon")
-
def validate_macaroon(self, macaroon, type_string, user_id):
"""
validate that a Macaroon is understood by and was signed by this server.
@@ -465,21 +450,13 @@ class Auth:
v.satisfy_exact("type = " + type_string)
v.satisfy_exact("user_id = %s" % user_id)
v.satisfy_exact("guest = true")
- v.satisfy_general(self._verify_expiry)
+ satisfy_expiry(v, self.clock.time_msec)
# access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.verify(macaroon, self._macaroon_secret_key)
- def _verify_expiry(self, caveat):
- prefix = "time < "
- if not caveat.startswith(prefix):
- return False
- expiry = int(caveat[len(prefix) :])
- now = self.hs.get_clock().time_msec()
- return now < expiry
-
def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 93c2aabcca..9d3bbe3b8b 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -90,7 +90,7 @@ class ApplicationServiceApi(SimpleHttpClient):
self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache(
- hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
+ hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
) # type: ResponseCache[Tuple[str, str]]
async def query_user(self, service, user_id):
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 4026966711..ba9cd63cf2 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -212,9 +212,8 @@ class Config:
@classmethod
def read_file(cls, file_path, config_name):
- cls.check_file(file_path, config_name)
- with open(file_path) as file_stream:
- return file_stream.read()
+ """Deprecated: call read_file directly"""
+ return read_file(file_path, (config_name,))
def read_template(self, filename: str) -> jinja2.Template:
"""Load a template file from disk.
@@ -894,4 +893,35 @@ class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
return self._get_instance(key)
-__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
+def read_file(file_path: Any, config_path: Iterable[str]) -> str:
+ """Check the given file exists, and read it into a string
+
+ If it does not, emit an error indicating the problem
+
+ Args:
+ file_path: the file to be read
+ config_path: where in the configuration file_path came from, so that a useful
+ error can be emitted if it does not exist.
+ Returns:
+ content of the file.
+ Raises:
+ ConfigError if there is a problem reading the file.
+ """
+ if not isinstance(file_path, str):
+ raise ConfigError("%r is not a string", config_path)
+
+ try:
+ os.stat(file_path)
+ with open(file_path) as file_stream:
+ return file_stream.read()
+ except OSError as e:
+ raise ConfigError("Error accessing file %r" % (file_path,), config_path) from e
+
+
+__all__ = [
+ "Config",
+ "RootConfig",
+ "ShardedWorkerHandlingConfig",
+ "RoutableShardedWorkerHandlingConfig",
+ "read_file",
+]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index db16c86f50..e896fd34e2 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -152,3 +152,5 @@ class ShardedWorkerHandlingConfig:
class RoutableShardedWorkerHandlingConfig(ShardedWorkerHandlingConfig):
def get_instance(self, key: str) -> str: ...
+
+def read_file(file_path: Any, config_path: Iterable[str]) -> str: ...
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index e56cf846f5..999aecce5c 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -21,8 +21,10 @@ import threading
from string import Template
import yaml
+from zope.interface import implementer
from twisted.logger import (
+ ILogObserver,
LogBeginner,
STDLibLogObserver,
eventAsText,
@@ -227,7 +229,8 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
threadlocal = threading.local()
- def _log(event):
+ @implementer(ILogObserver)
+ def _log(event: dict) -> None:
if "log_text" in event:
if event["log_text"].startswith("DNSDatagramProtocol starting on "):
return
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index a27594befc..7f5e449eb2 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -15,7 +15,7 @@
# limitations under the License.
from collections import Counter
-from typing import Iterable, Optional, Tuple, Type
+from typing import Iterable, Mapping, Optional, Tuple, Type
import attr
@@ -25,7 +25,7 @@ from synapse.types import Collection, JsonDict
from synapse.util.module_loader import load_module
from synapse.util.stringutils import parse_and_validate_mxc_uri
-from ._base import Config, ConfigError
+from ._base import Config, ConfigError, read_file
DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.oidc_handler.JinjaOidcMappingProvider"
@@ -97,7 +97,26 @@ class OIDCConfig(Config):
#
# client_id: Required. oauth2 client id to use.
#
- # client_secret: Required. oauth2 client secret to use.
+ # client_secret: oauth2 client secret to use. May be omitted if
+ # client_secret_jwt_key is given, or if client_auth_method is 'none'.
+ #
+ # client_secret_jwt_key: Alternative to client_secret: details of a key used
+ # to create a JSON Web Token to be used as an OAuth2 client secret. If
+ # given, must be a dictionary with the following properties:
+ #
+ # key: a pem-encoded signing key. Must be a suitable key for the
+ # algorithm specified. Required unless 'key_file' is given.
+ #
+ # key_file: the path to file containing a pem-encoded signing key file.
+ # Required unless 'key' is given.
+ #
+ # jwt_header: a dictionary giving properties to include in the JWT
+ # header. Must include the key 'alg', giving the algorithm used to
+ # sign the JWT, such as "ES256", using the JWA identifiers in
+ # RFC7518.
+ #
+ # jwt_payload: an optional dictionary giving properties to include in
+ # the JWT payload. Normally this should include an 'iss' key.
#
# client_auth_method: auth method to use when exchanging the token. Valid
# values are 'client_secret_basic' (default), 'client_secret_post' and
@@ -240,7 +259,7 @@ class OIDCConfig(Config):
# jsonschema definition of the configuration settings for an oidc identity provider
OIDC_PROVIDER_CONFIG_SCHEMA = {
"type": "object",
- "required": ["issuer", "client_id", "client_secret"],
+ "required": ["issuer", "client_id"],
"properties": {
"idp_id": {
"type": "string",
@@ -262,6 +281,30 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
"issuer": {"type": "string"},
"client_id": {"type": "string"},
"client_secret": {"type": "string"},
+ "client_secret_jwt_key": {
+ "type": "object",
+ "required": ["jwt_header"],
+ "oneOf": [
+ {"required": ["key"]},
+ {"required": ["key_file"]},
+ ],
+ "properties": {
+ "key": {"type": "string"},
+ "key_file": {"type": "string"},
+ "jwt_header": {
+ "type": "object",
+ "required": ["alg"],
+ "properties": {
+ "alg": {"type": "string"},
+ },
+ "additionalProperties": {"type": "string"},
+ },
+ "jwt_payload": {
+ "type": "object",
+ "additionalProperties": {"type": "string"},
+ },
+ },
+ },
"client_auth_method": {
"type": "string",
# the following list is the same as the keys of
@@ -404,6 +447,20 @@ def _parse_oidc_config_dict(
"idp_icon must be a valid MXC URI", config_path + ("idp_icon",)
) from e
+ client_secret_jwt_key_config = oidc_config.get("client_secret_jwt_key")
+ client_secret_jwt_key = None # type: Optional[OidcProviderClientSecretJwtKey]
+ if client_secret_jwt_key_config is not None:
+ keyfile = client_secret_jwt_key_config.get("key_file")
+ if keyfile:
+ key = read_file(keyfile, config_path + ("client_secret_jwt_key",))
+ else:
+ key = client_secret_jwt_key_config["key"]
+ client_secret_jwt_key = OidcProviderClientSecretJwtKey(
+ key=key,
+ jwt_header=client_secret_jwt_key_config["jwt_header"],
+ jwt_payload=client_secret_jwt_key_config.get("jwt_payload", {}),
+ )
+
return OidcProviderConfig(
idp_id=idp_id,
idp_name=oidc_config.get("idp_name", "OIDC"),
@@ -412,7 +469,8 @@ def _parse_oidc_config_dict(
discover=oidc_config.get("discover", True),
issuer=oidc_config["issuer"],
client_id=oidc_config["client_id"],
- client_secret=oidc_config["client_secret"],
+ client_secret=oidc_config.get("client_secret"),
+ client_secret_jwt_key=client_secret_jwt_key,
client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"),
scopes=oidc_config.get("scopes", ["openid"]),
authorization_endpoint=oidc_config.get("authorization_endpoint"),
@@ -428,6 +486,18 @@ def _parse_oidc_config_dict(
@attr.s(slots=True, frozen=True)
+class OidcProviderClientSecretJwtKey:
+ # a pem-encoded signing key
+ key = attr.ib(type=str)
+
+ # properties to include in the JWT header
+ jwt_header = attr.ib(type=Mapping[str, str])
+
+ # properties to include in the JWT payload.
+ jwt_payload = attr.ib(type=Mapping[str, str])
+
+
+@attr.s(slots=True, frozen=True)
class OidcProviderConfig:
# a unique identifier for this identity provider. Used in the 'user_external_ids'
# table, as well as the query/path parameter used in the login protocol.
@@ -452,8 +522,13 @@ class OidcProviderConfig:
# oauth2 client id to use
client_id = attr.ib(type=str)
- # oauth2 client secret to use
- client_secret = attr.ib(type=str)
+ # oauth2 client secret to use. if `None`, use client_secret_jwt_key to generate
+ # a secret.
+ client_secret = attr.ib(type=Optional[str])
+
+ # key to use to construct a JWT to use as a client secret. May be `None` if
+ # `client_secret` is set.
+ client_secret_jwt_key = attr.ib(type=Optional[OidcProviderClientSecretJwtKey])
# auth method to use when exchanging the token.
# Valid values are 'client_secret_basic', 'client_secret_post' and
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 2afca36e7d..5f8910b6e1 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -841,8 +841,7 @@ class ServerConfig(Config):
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
# 'false'. Note that profile data is also available via the federation
- # API, so this setting is of limited value if federation is enabled on
- # the server.
+ # API, unless allow_profile_lookup_over_federation is set to false.
#
#require_auth_for_profile_requests: true
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 93aa199119..f8e368f81b 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,6 +22,7 @@ from typing import (
Awaitable,
Callable,
Dict,
+ Iterable,
List,
Optional,
Tuple,
@@ -90,16 +91,15 @@ pdu_process_time = Histogram(
"Time taken to process an event",
)
-
-last_pdu_age_metric = Gauge(
- "synapse_federation_last_received_pdu_age",
- "The age (in seconds) of the last PDU successfully received from the given domain",
+last_pdu_ts_metric = Gauge(
+ "synapse_federation_last_received_pdu_time",
+ "The timestamp of the last PDU which was successfully received from the given domain",
labelnames=("server_name",),
)
class FederationServer(FederationBase):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.auth = hs.get_auth()
@@ -119,7 +119,7 @@ class FederationServer(FederationBase):
# We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache(
- hs, "fed_txn_handler", timeout_ms=30000
+ hs.get_clock(), "fed_txn_handler", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self.transaction_actions = TransactionActions(self.store)
@@ -129,10 +129,10 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(
- hs, "state_resp", timeout_ms=30000
+ hs.get_clock(), "state_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._state_ids_resp_cache = ResponseCache(
- hs, "state_ids_resp", timeout_ms=30000
+ hs.get_clock(), "state_ids_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._federation_metrics_domains = (
@@ -361,7 +361,7 @@ class FederationServer(FederationBase):
logger.error(
"Failed to handle PDU %s",
event_id,
- exc_info=(f.type, f.value, f.getTracebackObject()),
+ exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
await concurrently_execute(
@@ -369,8 +369,7 @@ class FederationServer(FederationBase):
)
if newest_pdu_ts and origin in self._federation_metrics_domains:
- newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
- last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
+ last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000)
return pdu_results
@@ -455,7 +454,9 @@ class FederationServer(FederationBase):
self, room_id: str, event_id: str
) -> Dict[str, list]:
if event_id:
- pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+ pdus = await self.handler.get_state_for_pdu(
+ room_id, event_id
+ ) # type: Iterable[EventBase]
else:
pdus = (await self.state.get_current_state(room_id)).values()
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 763aff296c..2a9cd063c4 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -36,9 +36,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-last_pdu_age_metric = Gauge(
- "synapse_federation_last_sent_pdu_age",
- "The age (in seconds) of the last PDU successfully sent to the given domain",
+last_pdu_ts_metric = Gauge(
+ "synapse_federation_last_sent_pdu_time",
+ "The timestamp of the last PDU which was successfully sent to the given domain",
labelnames=("server_name",),
)
@@ -187,9 +187,8 @@ class TransactionManager:
if success and pdus and destination in self._federation_metrics_domains:
last_pdu = pdus[-1]
- last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
- last_pdu_age_metric.labels(server_name=destination).set(
- last_pdu_age / 1000
+ last_pdu_ts_metric.labels(server_name=destination).set(
+ last_pdu.origin_server_ts / 1000
)
set_tag(tags.ERROR, not success)
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 5ecb2da1ac..132be238dd 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -73,7 +73,9 @@ class AcmeHandler:
"Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
)
try:
- self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
+ self.reactor.listenTCP(
+ self.hs.config.acme_port, srv, backlog=50, interface=host
+ )
except twisted.internet.error.CannotListenError as e:
check_bind_error(e, host, bind_addresses)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3978e41518..bec0c615d4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
@@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
extra_attributes = attr.ib(type=JsonDict)
+@attr.s(slots=True, frozen=True)
+class LoginTokenAttributes:
+ """Data we store in a short-term login token"""
+
+ user_id = attr.ib(type=str)
+
+ # the SSO Identity Provider that the user authenticated with, to get this token
+ auth_provider_id = attr.ib(type=str)
+
+
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
@@ -1164,18 +1175,16 @@ class AuthHandler(BaseHandler):
return None
return user_id
- async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
- auth_api = self.hs.get_auth()
- user_id = None
+ async def validate_short_term_login_token(
+ self, login_token: str
+ ) -> LoginTokenAttributes:
try:
- macaroon = pymacaroons.Macaroon.deserialize(login_token)
- user_id = auth_api.get_user_id_from_macaroon(macaroon)
- auth_api.validate_macaroon(macaroon, "login", user_id)
+ res = self.macaroon_gen.verify_short_term_login_token(login_token)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
- await self.auth.check_auth_blocking(user_id)
- return user_id
+ await self.auth.check_auth_blocking(res.user_id)
+ return res
async def delete_access_token(self, access_token: str):
"""Invalidate a single access token
@@ -1397,6 +1406,7 @@ class AuthHandler(BaseHandler):
async def complete_sso_login(
self,
registered_user_id: str,
+ auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
@@ -1406,6 +1416,9 @@ class AuthHandler(BaseHandler):
Args:
registered_user_id: The registered user ID to complete SSO login for.
+ auth_provider_id: The id of the SSO Identity provider that was used for
+ login. This will be stored in the login token for future tracking in
+ prometheus metrics.
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
@@ -1427,6 +1440,7 @@ class AuthHandler(BaseHandler):
self._complete_sso_login(
registered_user_id,
+ auth_provider_id,
request,
client_redirect_url,
extra_attributes,
@@ -1437,6 +1451,7 @@ class AuthHandler(BaseHandler):
def _complete_sso_login(
self,
registered_user_id: str,
+ auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
@@ -1463,7 +1478,7 @@ class AuthHandler(BaseHandler):
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
- registered_user_id
+ registered_user_id, auth_provider_id=auth_provider_id
)
# Append the login token to the original redirect URL (i.e. with its query
@@ -1569,15 +1584,48 @@ class MacaroonGenerator:
return macaroon.serialize()
def generate_short_term_login_token(
- self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ self,
+ user_id: str,
+ auth_provider_id: str,
+ duration_in_ms: int = (2 * 60 * 1000),
) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
+ macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
return macaroon.serialize()
+ def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
+ """Verify a short-term-login macaroon
+
+ Checks that the given token is a valid, unexpired short-term-login token
+ minted by this server.
+
+ Args:
+ token: the login token to verify
+
+ Returns:
+ the user_id that this token is valid for
+
+ Raises:
+ MacaroonVerificationFailedException if the verification failed
+ """
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+ user_id = get_value_from_macaroon(macaroon, "user_id")
+ auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+
+ v = pymacaroons.Verifier()
+ v.satisfy_exact("gen = 1")
+ v.satisfy_exact("type = login")
+ v.satisfy_general(lambda c: c.startswith("user_id = "))
+ v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+ satisfy_expiry(v, self.hs.get_clock().time_msec)
+ v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
+
+ return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+
def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 71a5076672..13f8152283 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -48,7 +48,7 @@ class InitialSyncHandler(BaseHandler):
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = ResponseCache(
- hs, "initial_sync_cache"
+ hs.get_clock(), "initial_sync_cache"
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 07db1e31e4..825fadb76f 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,13 +15,13 @@
# limitations under the License.
import inspect
import logging
-from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar
+from typing import TYPE_CHECKING, Dict, Generic, List, Optional, TypeVar, Union
from urllib.parse import urlencode
import attr
import pymacaroons
from authlib.common.security import generate_token
-from authlib.jose import JsonWebToken
+from authlib.jose import JsonWebToken, jwt
from authlib.oauth2.auth import ClientAuth
from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
from authlib.oidc.core import CodeIDToken, ImplicitIDToken, UserInfo
@@ -35,13 +36,17 @@ from typing_extensions import TypedDict
from twisted.web.client import readBody
from synapse.config import ConfigError
-from synapse.config.oidc_config import OidcProviderConfig
+from synapse.config.oidc_config import (
+ OidcProviderClientSecretJwtKey,
+ OidcProviderConfig,
+)
from synapse.handlers.sso import MappingException, UserAttributes
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
-from synapse.util import json_decoder
+from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -211,7 +216,7 @@ class OidcHandler:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
- except (MacaroonDeserializationException, ValueError) as e:
+ except (MacaroonDeserializationException, KeyError) as e:
logger.exception("Invalid session for OIDC callback")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
@@ -275,9 +280,21 @@ class OidcProvider:
self._scopes = provider.scopes
self._user_profile_method = provider.user_profile_method
+
+ client_secret = None # type: Union[None, str, JwtClientSecret]
+ if provider.client_secret:
+ client_secret = provider.client_secret
+ elif provider.client_secret_jwt_key:
+ client_secret = JwtClientSecret(
+ provider.client_secret_jwt_key,
+ provider.client_id,
+ provider.issuer,
+ hs.get_clock(),
+ )
+
self._client_auth = ClientAuth(
provider.client_id,
- provider.client_secret,
+ client_secret,
provider.client_auth_method,
) # type: ClientAuth
self._client_auth_method = provider.client_auth_method
@@ -745,7 +762,7 @@ class OidcProvider:
idp_id=self.idp_id,
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
- ui_auth_session_id=ui_auth_session_id,
+ ui_auth_session_id=ui_auth_session_id or "",
),
)
@@ -976,6 +993,81 @@ class OidcProvider:
return str(remote_user_id)
+# number of seconds a newly-generated client secret should be valid for
+CLIENT_SECRET_VALIDITY_SECONDS = 3600
+
+# minimum remaining validity on a client secret before we should generate a new one
+CLIENT_SECRET_MIN_VALIDITY_SECONDS = 600
+
+
+class JwtClientSecret:
+ """A class which generates a new client secret on demand, based on a JWK
+
+ This implementation is designed to comply with the requirements for Apple Sign in:
+ https://developer.apple.com/documentation/sign_in_with_apple/generate_and_validate_tokens#3262048
+
+ It looks like those requirements are based on https://tools.ietf.org/html/rfc7523,
+ but it's worth noting that we still put the generated secret in the "client_secret"
+ field (or rather, whereever client_auth_method puts it) rather than in a
+ client_assertion field in the body as that RFC seems to require.
+ """
+
+ def __init__(
+ self,
+ key: OidcProviderClientSecretJwtKey,
+ oauth_client_id: str,
+ oauth_issuer: str,
+ clock: Clock,
+ ):
+ self._key = key
+ self._oauth_client_id = oauth_client_id
+ self._oauth_issuer = oauth_issuer
+ self._clock = clock
+ self._cached_secret = b""
+ self._cached_secret_replacement_time = 0
+
+ def __str__(self):
+ # if client_auth_method is client_secret_basic, then ClientAuth.prepare calls
+ # encode_client_secret_basic, which calls "{}".format(secret), which ends up
+ # here.
+ return self._get_secret().decode("ascii")
+
+ def __bytes__(self):
+ # if client_auth_method is client_secret_post, then ClientAuth.prepare calls
+ # encode_client_secret_post, which ends up here.
+ return self._get_secret()
+
+ def _get_secret(self) -> bytes:
+ now = self._clock.time()
+
+ # if we have enough validity on our existing secret, use it
+ if now < self._cached_secret_replacement_time:
+ return self._cached_secret
+
+ issued_at = int(now)
+ expires_at = issued_at + CLIENT_SECRET_VALIDITY_SECONDS
+
+ # we copy the configured header because jwt.encode modifies it.
+ header = dict(self._key.jwt_header)
+
+ # see https://tools.ietf.org/html/rfc7523#section-3
+ payload = {
+ "sub": self._oauth_client_id,
+ "aud": self._oauth_issuer,
+ "iat": issued_at,
+ "exp": expires_at,
+ **self._key.jwt_payload,
+ }
+ logger.info(
+ "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
+ )
+ self._cached_secret = jwt.encode(header, payload, self._key.key)
+ self._cached_secret_replacement_time = (
+ expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
+ )
+ return self._cached_secret
+
+
class OidcSessionTokenGenerator:
"""Methods for generating and checking OIDC Session cookies."""
@@ -1020,10 +1112,9 @@ class OidcSessionTokenGenerator:
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (session_data.client_redirect_url,)
)
- if session_data.ui_auth_session_id:
- macaroon.add_first_party_caveat(
- "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
- )
+ macaroon.add_first_party_caveat(
+ "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+ )
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -1046,7 +1137,7 @@ class OidcSessionTokenGenerator:
The data extracted from the session cookie
Raises:
- ValueError if an expected caveat is missing from the macaroon.
+ KeyError if an expected caveat is missing from the macaroon.
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
@@ -1057,26 +1148,16 @@ class OidcSessionTokenGenerator:
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("idp_id = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
- # Sometimes there's a UI auth session ID, it seems to be OK to attempt
- # to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
- v.satisfy_general(self._verify_expiry)
+ satisfy_expiry(v, self._clock.time_msec)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the session data from the token.
- nonce = self._get_value_from_macaroon(macaroon, "nonce")
- idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
- client_redirect_url = self._get_value_from_macaroon(
- macaroon, "client_redirect_url"
- )
- try:
- ui_auth_session_id = self._get_value_from_macaroon(
- macaroon, "ui_auth_session_id"
- ) # type: Optional[str]
- except ValueError:
- ui_auth_session_id = None
-
+ nonce = get_value_from_macaroon(macaroon, "nonce")
+ idp_id = get_value_from_macaroon(macaroon, "idp_id")
+ client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
+ ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
return OidcSessionData(
nonce=nonce,
idp_id=idp_id,
@@ -1084,33 +1165,6 @@ class OidcSessionTokenGenerator:
ui_auth_session_id=ui_auth_session_id,
)
- def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
- """Extracts a caveat value from a macaroon token.
-
- Args:
- macaroon: the token
- key: the key of the caveat to extract
-
- Returns:
- The extracted value
-
- Raises:
- ValueError: if the caveat was not in the macaroon
- """
- prefix = key + " = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(prefix):
- return caveat.caveat_id[len(prefix) :]
- raise ValueError("No %s caveat in macaroon" % (key,))
-
- def _verify_expiry(self, caveat: str) -> bool:
- prefix = "time < "
- if not caveat.startswith(prefix):
- return False
- expiry = int(caveat[len(prefix) :])
- now = self._clock.time_msec()
- return now < expiry
-
@attr.s(frozen=True, slots=True)
class OidcSessionData:
@@ -1125,8 +1179,8 @@ class OidcSessionData:
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
client_redirect_url = attr.ib(type=str)
- # The session ID of the ongoing UI Auth (None if this is a login)
- ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+ # The session ID of the ongoing UI Auth ("" if this is a login)
+ ui_auth_session_id = attr.ib(type=str)
UserAttributeDict = TypedDict(
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 059064a4eb..66dc886c81 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -285,7 +285,7 @@ class PaginationHandler:
except Exception:
f = Failure()
logger.error(
- "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject())
+ "[purge] failed", exc_info=(f.type, f.value, f.getTracebackObject()) # type: ignore
)
self._purges_by_id[purge_id].status = PurgeStatus.STATUS_FAILED
finally:
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 3cda89657e..b66f8756b8 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -18,6 +18,8 @@
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from prometheus_client import Counter
+
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
@@ -41,6 +43,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+registration_counter = Counter(
+ "synapse_user_registrations_total",
+ "Number of new users registered (since restart)",
+ ["guest", "shadow_banned", "auth_provider"],
+)
+
+login_counter = Counter(
+ "synapse_user_logins_total",
+ "Number of user logins (since restart)",
+ ["guest", "auth_provider"],
+)
+
+
class RegistrationHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@@ -156,6 +171,7 @@ class RegistrationHandler(BaseHandler):
bind_emails: Iterable[str] = [],
by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+ auth_provider_id: Optional[str] = None,
) -> str:
"""Registers a new client on the server.
@@ -181,8 +197,10 @@ class RegistrationHandler(BaseHandler):
admin api, otherwise False.
user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process.
+ auth_provider_id: The SSO IdP the user used, if any (just used for the
+ prometheus metrics).
Returns:
- The registere user_id.
+ The registered user_id.
Raises:
SynapseError if there was a problem registering.
"""
@@ -280,6 +298,12 @@ class RegistrationHandler(BaseHandler):
# if user id is taken, just generate another
fail_count += 1
+ registration_counter.labels(
+ guest=make_guest,
+ shadow_banned=shadow_banned,
+ auth_provider=(auth_provider_id or ""),
+ ).inc()
+
if not self.hs.config.user_consent_at_registration:
if not self.hs.config.auto_join_rooms_for_guests and make_guest:
logger.info(
@@ -638,6 +662,7 @@ class RegistrationHandler(BaseHandler):
initial_display_name: Optional[str],
is_guest: bool = False,
is_appservice_ghost: bool = False,
+ auth_provider_id: Optional[str] = None,
) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
@@ -648,7 +673,8 @@ class RegistrationHandler(BaseHandler):
device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account
-
+ auth_provider_id: The SSO IdP the user used, if any (just used for the
+ prometheus metrics).
Returns:
Tuple of device ID and access token
"""
@@ -687,6 +713,11 @@ class RegistrationHandler(BaseHandler):
is_appservice_ghost=is_appservice_ghost,
)
+ login_counter.labels(
+ guest=is_guest,
+ auth_provider=(auth_provider_id or ""),
+ ).inc()
+
return (registered_device_id, access_token)
async def post_registration_actions(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a488df10d6..4b3d0d72e3 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -121,7 +121,7 @@ class RoomCreationHandler(BaseHandler):
# succession, only process the first attempt and return its result to
# subsequent requests
self._upgrade_response_cache = ResponseCache(
- hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
+ hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
) # type: ResponseCache[Tuple[str, str]]
self._server_notices_mxid = hs.config.server_notices_mxid
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 70522e40fa..8c5b60e6be 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -45,10 +45,10 @@ class RoomListHandler(BaseHandler):
self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(
- hs, "room_list"
+ hs.get_clock(), "room_list"
) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
self.remote_response_cache = ResponseCache(
- hs, "remote_room_list", timeout_ms=30 * 1000
+ hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
async def get_local_public_room_list(
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 80e28bdcbe..6ef459acff 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -456,6 +456,7 @@ class SsoHandler:
await self._auth_handler.complete_sso_login(
user_id,
+ auth_provider_id,
request,
client_redirect_url,
extra_login_attributes,
@@ -605,6 +606,7 @@ class SsoHandler:
default_display_name=attributes.display_name,
bind_emails=attributes.emails,
user_agent_ips=[(user_agent, ip_address)],
+ auth_provider_id=auth_provider_id,
)
await self._store.record_user_external_id(
@@ -886,6 +888,7 @@ class SsoHandler:
await self._auth_handler.complete_sso_login(
user_id,
+ session.auth_provider_id,
request,
session.client_redirect_url,
session.extra_login_attributes,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 6c8e361402..a65299bd22 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -245,7 +245,7 @@ class SyncHandler:
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.response_cache = ResponseCache(
- hs, "sync", timeout_ms=SYNC_RESPONSE_CACHE_MS
+ hs.get_clock(), "sync", timeout_ms=SYNC_RESPONSE_CACHE_MS
) # type: ResponseCache[Tuple[Any, ...]]
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 72901e3f95..af34d583ad 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -63,6 +63,7 @@ from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_u
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.types import ISynapseReactor
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
@@ -199,7 +200,7 @@ class _IPBlacklistingResolver:
return r
-@implementer(IReactorPluggableNameResolver)
+@implementer(ISynapseReactor)
class BlacklistingReactorWrapper:
"""
A Reactor wrapper which will prevent DNS resolution to blacklisted IP
@@ -324,7 +325,7 @@ class SimpleHttpClient:
# filters out blacklisted IP addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
- )
+ ) # type: ISynapseReactor
else:
self.reactor = hs.get_reactor()
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index b07aa59c08..5935a125fd 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -35,6 +35,7 @@ from synapse.http.client import BlacklistingAgentWrapper
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import ISynapseReactor
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -68,7 +69,7 @@ class MatrixFederationAgent:
def __init__(
self,
- reactor: IReactorCore,
+ reactor: ISynapseReactor,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes,
ip_blacklist: IPSet,
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 4def7d7633..ecd63e6596 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -322,7 +322,8 @@ def _cache_period_from_headers(
def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
cache_controls = {}
- for hdr in headers.getRawHeaders(b"cache-control", []):
+ cache_control_headers = headers.getRawHeaders(b"cache-control") or []
+ for hdr in cache_control_headers:
for directive in hdr.split(b","):
splits = [x.strip() for x in directive.split(b"=", 1)]
k = splits[0].lower()
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 0f107714ea..5f01ebd3d4 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -59,7 +59,7 @@ from synapse.logging.opentracing import (
start_active_span,
tags,
)
-from synapse.types import JsonDict
+from synapse.types import ISynapseReactor, JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -237,14 +237,14 @@ class MatrixFederationHttpClient:
# addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
- )
+ ) # type: ISynapseReactor
user_agent = hs.version_string
if hs.config.user_agent_suffix:
user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
user_agent = user_agent.encode("ascii")
- self.agent = MatrixFederationAgent(
+ federation_agent = MatrixFederationAgent(
self.reactor,
tls_client_options_factory,
user_agent,
@@ -254,7 +254,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
- self.agent,
+ federation_agent,
ip_blacklist=hs.config.federation_ip_range_blacklist,
)
@@ -534,9 +534,10 @@ class MatrixFederationHttpClient:
response.code, response_phrase, body
)
- # Retry if the error is a 429 (Too Many Requests),
- # otherwise just raise a standard HttpResponseException
- if response.code == 429:
+ # Retry if the error is a 5xx or a 429 (Too Many
+ # Requests), otherwise just raise a standard
+ # `HttpResponseException`
+ if 500 <= response.code < 600 or response.code == 429:
raise RequestSendFailed(exc, can_retry=True) from exc
else:
raise exc
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index 78e27bfb00..1a7ea4fa96 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -669,7 +669,7 @@ def preserve_fn(f):
return g
-def run_in_background(f, *args, **kwargs):
+def run_in_background(f, *args, **kwargs) -> defer.Deferred:
"""Calls a function, ensuring that the current context is restored after
return from the function, and that the sentinel context is set once the
deferred returned by the function completes.
@@ -697,8 +697,10 @@ def run_in_background(f, *args, **kwargs):
if isinstance(res, types.CoroutineType):
res = defer.ensureDeferred(res)
+ # At this point we should have a Deferred, if not then f was a synchronous
+ # function, wrap it in a Deferred for consistency.
if not isinstance(res, defer.Deferred):
- return res
+ return defer.succeed(res)
if res.called and not res.paused:
# The function should have maintained the logcontext, so we can
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index db2d400b7e..781e02fbbb 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -203,11 +203,26 @@ class ModuleApi:
)
def generate_short_term_login_token(
- self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ self,
+ user_id: str,
+ duration_in_ms: int = (2 * 60 * 1000),
+ auth_provider_id: str = "",
) -> str:
- """Generate a login token suitable for m.login.token authentication"""
+ """Generate a login token suitable for m.login.token authentication
+
+ Args:
+ user_id: gives the ID of the user that the token is for
+
+ duration_in_ms: the time that the token will be valid for
+
+ auth_provider_id: the ID of the SSO IdP that the user used to authenticate
+ to get this token, if any. This is encoded in the token so that
+ /login can report stats on number of successful logins by IdP.
+ """
return self._hs.get_macaroon_generator().generate_short_term_login_token(
- user_id, duration_in_ms
+ user_id,
+ auth_provider_id,
+ duration_in_ms,
)
@defer.inlineCallbacks
@@ -276,6 +291,7 @@ class ModuleApi:
"""
self._auth_handler._complete_sso_login(
registered_user_id,
+ "<unknown>",
request,
client_redirect_url,
)
@@ -286,6 +302,7 @@ class ModuleApi:
request: SynapseRequest,
client_redirect_url: str,
new_user: bool = False,
+ auth_provider_id: str = "<unknown>",
):
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
@@ -299,9 +316,15 @@ class ModuleApi:
redirect them directly if whitelisted).
new_user: set to true to use wording for the consent appropriate to a user
who has just registered.
+ auth_provider_id: the ID of the SSO IdP which was used to log in. This
+ is used to track counts of sucessful logins by IdP.
"""
await self._auth_handler.complete_sso_login(
- registered_user_id, request, client_redirect_url, new_user=new_user
+ registered_user_id,
+ auth_provider_id,
+ request,
+ client_redirect_url,
+ new_user=new_user,
)
@defer.inlineCallbacks
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 8a3f113e76..b7aa0c280f 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -18,7 +18,7 @@ import logging
import re
import urllib
from inspect import signature
-from typing import Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple
from prometheus_client import Counter, Gauge
@@ -28,6 +28,9 @@ from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
_pending_outgoing_requests = Gauge(
@@ -88,10 +91,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
CACHE = True
RETRY_ON_TIMEOUT = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
if self.CACHE:
self.response_cache = ResponseCache(
- hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
+ hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
) # type: ResponseCache[str]
# We reserve `instance_name` as a parameter to sending requests, so we
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 0e6155cf53..7560706b4b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -328,6 +328,6 @@ def lazyConnection(
factory.continueTrying = reconnect
reactor = hs.get_reactor()
- reactor.connectTCP(host, port, factory, 30)
+ reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
return factory.handler
diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
index 8b7bb6d44e..49966ee3e0 100644
--- a/synapse/rest/admin/purge_room_servlet.py
+++ b/synapse/rest/admin/purge_room_servlet.py
@@ -12,13 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, Tuple
+
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class PurgeRoomServlet(RestServlet):
@@ -36,16 +43,12 @@ class PurgeRoomServlet(RestServlet):
PATTERNS = admin_patterns("/purge_room$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.pagination_handler = hs.get_pagination_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 375d055445..f495666f4a 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,17 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, Optional, Tuple
+
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class SendServerNoticeServlet(RestServlet):
@@ -44,17 +51,13 @@ class SendServerNoticeServlet(RestServlet):
}
"""
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)
self.snm = hs.get_server_notices_manager()
- def register(self, json_resource):
+ def register(self, json_resource: HttpServer):
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
@@ -66,7 +69,9 @@ class SendServerNoticeServlet(RestServlet):
self.__class__.__name__,
)
- async def on_POST(self, request, txn_id=None):
+ async def on_POST(
+ self, request: SynapseRequest, txn_id: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content"))
@@ -90,7 +95,7 @@ class SendServerNoticeServlet(RestServlet):
return 200, {"event_id": event.event_id}
- def on_PUT(self, request, txn_id):
+ def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, txn_id
)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 925edfc402..34bc1bd49b 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -219,6 +219,7 @@ class LoginRestServlet(RestServlet):
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
ratelimit: bool = True,
+ auth_provider_id: Optional[str] = None,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -234,6 +235,8 @@ class LoginRestServlet(RestServlet):
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
ratelimit: Whether to ratelimit the login request.
+ auth_provider_id: The SSO IdP the user used, if any (just used for the
+ prometheus metrics).
Returns:
result: Dictionary of account information after successful login.
@@ -256,7 +259,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name
+ user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
)
result = {
@@ -283,12 +286,13 @@ class LoginRestServlet(RestServlet):
"""
token = login_submission["token"]
auth_handler = self.auth_handler
- user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
- token
- )
+ res = await auth_handler.validate_short_term_login_token(token)
return await self._complete_login(
- user_id, login_submission, self.auth_handler._sso_login_callback
+ res.user_id,
+ login_submission,
+ self.auth_handler._sso_login_callback,
+ auth_provider_id=res.auth_provider_id,
)
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py
index 07903e4017..988f52c78f 100644
--- a/synapse/rest/media/v1/thumbnailer.py
+++ b/synapse/rest/media/v1/thumbnailer.py
@@ -96,9 +96,14 @@ class Thumbnailer:
def _resize(self, width: int, height: int) -> Image:
# 1-bit or 8-bit color palette images need converting to RGB
# otherwise they will be scaled using nearest neighbour which
- # looks awful
- if self.image.mode in ["1", "P"]:
- self.image = self.image.convert("RGB")
+ # looks awful.
+ #
+ # If the image has transparency, use RGBA instead.
+ if self.image.mode in ["1", "L", "P"]:
+ mode = "RGB"
+ if self.image.info.get("transparency", None) is not None:
+ mode = "RGBA"
+ self.image = self.image.convert(mode)
return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width: int, height: int, output_type: str) -> BytesIO:
diff --git a/synapse/server.py b/synapse/server.py
index afd7cd72e7..369cc88026 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -36,7 +36,6 @@ from typing import (
cast,
)
-import twisted.internet.base
import twisted.internet.tcp
from twisted.internet import defer
from twisted.mail.smtp import sendmail
@@ -130,7 +129,7 @@ from synapse.server_notices.worker_server_notices_sender import (
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import Databases, DataStore, Storage
from synapse.streams.events import EventSources
-from synapse.types import DomainSpecificString
+from synapse.types import DomainSpecificString, ISynapseReactor
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -291,7 +290,7 @@ class HomeServer(metaclass=abc.ABCMeta):
for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
getattr(self, "get_" + i + "_handler")()
- def get_reactor(self) -> twisted.internet.base.ReactorBase:
+ def get_reactor(self) -> ISynapseReactor:
"""
Fetch the Twisted reactor in use by this HomeServer.
"""
diff --git a/synapse/types.py b/synapse/types.py
index 721343f0b5..0216d213c7 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -35,6 +35,14 @@ from typing import (
import attr
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
+from zope.interface import Interface
+
+from twisted.internet.interfaces import (
+ IReactorCore,
+ IReactorPluggableNameResolver,
+ IReactorTCP,
+ IReactorTime,
+)
from synapse.api.errors import Codes, SynapseError
from synapse.util.stringutils import parse_and_validate_server_name
@@ -67,6 +75,14 @@ MutableStateMap = MutableMapping[StateKey, T]
JsonDict = Dict[str, Any]
+# Note that this seems to require inheriting *directly* from Interface in order
+# for mypy-zope to realize it is an interface.
+class ISynapseReactor(
+ IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
+):
+ """The interfaces necessary for Synapse to function."""
+
+
class Requester(
namedtuple(
"Requester",
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 719e35b78d..f33c115844 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -76,11 +76,16 @@ class ObservableDeferred:
def callback(r):
object.__setattr__(self, "_result", (True, r))
while self._observers:
+ observer = self._observers.pop()
try:
- # TODO: Handle errors here.
- self._observers.pop().callback(r)
- except Exception:
- pass
+ observer.callback(r)
+ except Exception as e:
+ logger.exception(
+ "%r threw an exception on .callback(%r), ignoring...",
+ observer,
+ r,
+ exc_info=e,
+ )
return r
def errback(f):
@@ -90,11 +95,16 @@ class ObservableDeferred:
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f
+ observer = self._observers.pop()
try:
- # TODO: Handle errors here.
- self._observers.pop().errback(f)
- except Exception:
- pass
+ observer.errback(f)
+ except Exception as e:
+ logger.exception(
+ "%r threw an exception on .errback(%r), ignoring...",
+ observer,
+ f,
+ exc_info=e,
+ )
if consumeErrors:
return None
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 32228f42ee..46ea8e0964 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -13,17 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
+from typing import Any, Callable, Dict, Generic, Optional, TypeVar
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.util import Clock
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
-if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
-
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -37,11 +35,11 @@ class ResponseCache(Generic[T]):
used rather than trying to compute a new response.
"""
- def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
+ def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
# Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
- self.clock = hs.get_clock()
+ self.clock = clock
self.timeout_sec = timeout_ms / 1000.0
self._name = name
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
new file mode 100644
index 0000000000..12cdd53327
--- /dev/null
+++ b/synapse/util/macaroons.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for manipulating macaroons"""
+
+from typing import Callable, Optional
+
+import pymacaroons
+from pymacaroons.exceptions import MacaroonVerificationFailedException
+
+
+def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+ """Extracts a caveat value from a macaroon token.
+
+ Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
+ and returns the extracted value.
+
+ Args:
+ macaroon: the token
+ key: the key of the caveat to extract
+
+ Returns:
+ The extracted value
+
+ Raises:
+ MacaroonVerificationFailedException: if there are conflicting values for the
+ caveat in the macaroon, or if the caveat was not found in the macaroon.
+ """
+ prefix = key + " = "
+ result = None # type: Optional[str]
+ for caveat in macaroon.caveats:
+ if not caveat.caveat_id.startswith(prefix):
+ continue
+
+ val = caveat.caveat_id[len(prefix) :]
+
+ if result is None:
+ # first time we found this caveat: record the value
+ result = val
+ elif val != result:
+ # on subsequent occurrences, raise if the value is different.
+ raise MacaroonVerificationFailedException(
+ "Conflicting values for caveat " + key
+ )
+
+ if result is not None:
+ return result
+
+ # If the caveat is not there, we raise a MacaroonVerificationFailedException.
+ # Note that it is insecure to generate a macaroon without all the caveats you
+ # might need (because there is nothing stopping people from adding extra caveats),
+ # so if the caveat isn't there, something odd must be going on.
+ raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
+
+
+def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
+ """Make a macaroon verifier which accepts 'time' caveats
+
+ Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
+ the given macaroon verifier.
+
+ Args:
+ v: the macaroon verifier
+ get_time_ms: a callable which will return the timestamp after which the caveat
+ should be considered expired. Normally the current time.
+ """
+
+ def verify_expiry_caveat(caveat: str):
+ time_msec = get_time_ms()
+ prefix = "time < "
+ if not caveat.startswith(prefix):
+ return False
+ expiry = int(caveat[len(prefix) :])
+ return time_msec < expiry
+
+ v.satisfy_general(verify_expiry_caveat)
|