diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index dbf3799d2e..3370bc74cf 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -39,6 +39,7 @@ from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -413,7 +414,7 @@ class Auth:
raise _InvalidMacaroonException()
try:
- user_id = self.get_user_id_from_macaroon(macaroon)
+ user_id = get_value_from_macaroon(macaroon, "user_id")
guest = False
for caveat in macaroon.caveats:
@@ -421,7 +422,12 @@ class Auth:
guest = True
self.validate_macaroon(macaroon, rights, user_id=user_id)
- except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
+ except (
+ pymacaroons.exceptions.MacaroonException,
+ KeyError,
+ TypeError,
+ ValueError,
+ ):
raise InvalidClientTokenError("Invalid macaroon passed.")
if rights == "access":
@@ -429,27 +435,6 @@ class Auth:
return user_id, guest
- def get_user_id_from_macaroon(self, macaroon):
- """Retrieve the user_id given by the caveats on the macaroon.
-
- Does *not* validate the macaroon.
-
- Args:
- macaroon (pymacaroons.Macaroon): The macaroon to validate
-
- Returns:
- (str) user id
-
- Raises:
- InvalidClientCredentialsError if there is no user_id caveat in the
- macaroon
- """
- user_prefix = "user_id = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(user_prefix):
- return caveat.caveat_id[len(user_prefix) :]
- raise InvalidClientTokenError("No user caveat in macaroon")
-
def validate_macaroon(self, macaroon, type_string, user_id):
"""
validate that a Macaroon is understood by and was signed by this server.
@@ -470,21 +455,13 @@ class Auth:
v.satisfy_exact("type = " + type_string)
v.satisfy_exact("user_id = %s" % user_id)
v.satisfy_exact("guest = true")
- v.satisfy_general(self._verify_expiry)
+ satisfy_expiry(v, self.clock.time_msec)
# access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.verify(macaroon, self._macaroon_secret_key)
- def _verify_expiry(self, caveat):
- prefix = "time < "
- if not caveat.startswith(prefix):
- return False
- expiry = int(caveat[len(prefix) :])
- now = self.hs.get_clock().time_msec()
- return now < expiry
-
def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 09104488f8..8c5d178e05 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -90,7 +90,7 @@ class ApplicationServiceApi(SimpleHttpClient):
self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache(
- hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
+ hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
) # type: ResponseCache[Tuple[str, str]]
async def query_user(self, service, user_id):
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 29ff7718fd..c8b1a25004 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -847,8 +847,7 @@ class ServerConfig(Config):
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
# 'false'. Note that profile data is also available via the federation
- # API, so this setting is of limited value if federation is enabled on
- # the server.
+ # API, unless allow_profile_lookup_over_federation is set to false.
#
#require_auth_for_profile_requests: true
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index de7c2e5f77..26b8105dec 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,6 +22,7 @@ from typing import (
Awaitable,
Callable,
Dict,
+ Iterable,
List,
Optional,
Tuple,
@@ -91,16 +92,15 @@ pdu_process_time = Histogram(
"Time taken to process an event",
)
-
-last_pdu_age_metric = Gauge(
- "synapse_federation_last_received_pdu_age",
- "The age (in seconds) of the last PDU successfully received from the given domain",
+last_pdu_ts_metric = Gauge(
+ "synapse_federation_last_received_pdu_time",
+ "The timestamp of the last PDU which was successfully received from the given domain",
labelnames=("server_name",),
)
class FederationServer(FederationBase):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.auth = hs.get_auth()
@@ -120,7 +120,7 @@ class FederationServer(FederationBase):
# We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache(
- hs, "fed_txn_handler", timeout_ms=30000
+ hs.get_clock(), "fed_txn_handler", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self.transaction_actions = TransactionActions(self.store)
@@ -130,10 +130,10 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(
- hs, "state_resp", timeout_ms=30000
+ hs.get_clock(), "state_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._state_ids_resp_cache = ResponseCache(
- hs, "state_ids_resp", timeout_ms=30000
+ hs.get_clock(), "state_ids_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._federation_metrics_domains = (
@@ -370,8 +370,7 @@ class FederationServer(FederationBase):
)
if newest_pdu_ts and origin in self._federation_metrics_domains:
- newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
- last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
+ last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000)
return pdu_results
@@ -456,7 +455,9 @@ class FederationServer(FederationBase):
self, room_id: str, event_id: str
) -> Dict[str, list]:
if event_id:
- pdus = await self.handler.get_state_for_pdu(room_id, event_id)
+ pdus = await self.handler.get_state_for_pdu(
+ room_id, event_id
+ ) # type: Iterable[EventBase]
else:
pdus = (await self.state.get_current_state(room_id)).values()
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 763aff296c..2a9cd063c4 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -36,9 +36,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-last_pdu_age_metric = Gauge(
- "synapse_federation_last_sent_pdu_age",
- "The age (in seconds) of the last PDU successfully sent to the given domain",
+last_pdu_ts_metric = Gauge(
+ "synapse_federation_last_sent_pdu_time",
+ "The timestamp of the last PDU which was successfully sent to the given domain",
labelnames=("server_name",),
)
@@ -187,9 +187,8 @@ class TransactionManager:
if success and pdus and destination in self._federation_metrics_domains:
last_pdu = pdus[-1]
- last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
- last_pdu_age_metric.labels(server_name=destination).set(
- last_pdu_age / 1000
+ last_pdu_ts_metric.labels(server_name=destination).set(
+ last_pdu.origin_server_ts / 1000
)
set_tag(tags.ERROR, not success)
diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py
index 5ecb2da1ac..132be238dd 100644
--- a/synapse/handlers/acme.py
+++ b/synapse/handlers/acme.py
@@ -73,7 +73,9 @@ class AcmeHandler:
"Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
)
try:
- self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
+ self.reactor.listenTCP(
+ self.hs.config.acme_port, srv, backlog=50, interface=host
+ )
except twisted.internet.error.CannotListenError as e:
check_bind_error(e, host, bind_addresses)
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3978e41518..bec0c615d4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
@@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
extra_attributes = attr.ib(type=JsonDict)
+@attr.s(slots=True, frozen=True)
+class LoginTokenAttributes:
+ """Data we store in a short-term login token"""
+
+ user_id = attr.ib(type=str)
+
+ # the SSO Identity Provider that the user authenticated with, to get this token
+ auth_provider_id = attr.ib(type=str)
+
+
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
@@ -1164,18 +1175,16 @@ class AuthHandler(BaseHandler):
return None
return user_id
- async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
- auth_api = self.hs.get_auth()
- user_id = None
+ async def validate_short_term_login_token(
+ self, login_token: str
+ ) -> LoginTokenAttributes:
try:
- macaroon = pymacaroons.Macaroon.deserialize(login_token)
- user_id = auth_api.get_user_id_from_macaroon(macaroon)
- auth_api.validate_macaroon(macaroon, "login", user_id)
+ res = self.macaroon_gen.verify_short_term_login_token(login_token)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
- await self.auth.check_auth_blocking(user_id)
- return user_id
+ await self.auth.check_auth_blocking(res.user_id)
+ return res
async def delete_access_token(self, access_token: str):
"""Invalidate a single access token
@@ -1397,6 +1406,7 @@ class AuthHandler(BaseHandler):
async def complete_sso_login(
self,
registered_user_id: str,
+ auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
@@ -1406,6 +1416,9 @@ class AuthHandler(BaseHandler):
Args:
registered_user_id: The registered user ID to complete SSO login for.
+ auth_provider_id: The id of the SSO Identity provider that was used for
+ login. This will be stored in the login token for future tracking in
+ prometheus metrics.
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
@@ -1427,6 +1440,7 @@ class AuthHandler(BaseHandler):
self._complete_sso_login(
registered_user_id,
+ auth_provider_id,
request,
client_redirect_url,
extra_attributes,
@@ -1437,6 +1451,7 @@ class AuthHandler(BaseHandler):
def _complete_sso_login(
self,
registered_user_id: str,
+ auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
@@ -1463,7 +1478,7 @@ class AuthHandler(BaseHandler):
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
- registered_user_id
+ registered_user_id, auth_provider_id=auth_provider_id
)
# Append the login token to the original redirect URL (i.e. with its query
@@ -1569,15 +1584,48 @@ class MacaroonGenerator:
return macaroon.serialize()
def generate_short_term_login_token(
- self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ self,
+ user_id: str,
+ auth_provider_id: str,
+ duration_in_ms: int = (2 * 60 * 1000),
) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
+ macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
return macaroon.serialize()
+ def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
+ """Verify a short-term-login macaroon
+
+ Checks that the given token is a valid, unexpired short-term-login token
+ minted by this server.
+
+ Args:
+ token: the login token to verify
+
+ Returns:
+ the user_id that this token is valid for
+
+ Raises:
+ MacaroonVerificationFailedException if the verification failed
+ """
+ macaroon = pymacaroons.Macaroon.deserialize(token)
+ user_id = get_value_from_macaroon(macaroon, "user_id")
+ auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
+
+ v = pymacaroons.Verifier()
+ v.satisfy_exact("gen = 1")
+ v.satisfy_exact("type = login")
+ v.satisfy_general(lambda c: c.startswith("user_id = "))
+ v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
+ satisfy_expiry(v, self.hs.get_clock().time_msec)
+ v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
+
+ return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
+
def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 71a5076672..13f8152283 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -48,7 +48,7 @@ class InitialSyncHandler(BaseHandler):
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = ResponseCache(
- hs, "initial_sync_cache"
+ hs.get_clock(), "initial_sync_cache"
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 07db1e31e4..b4a74390cc 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -42,6 +42,7 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -211,7 +212,7 @@ class OidcHandler:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
- except (MacaroonDeserializationException, ValueError) as e:
+ except (MacaroonDeserializationException, KeyError) as e:
logger.exception("Invalid session for OIDC callback")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
@@ -745,7 +746,7 @@ class OidcProvider:
idp_id=self.idp_id,
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
- ui_auth_session_id=ui_auth_session_id,
+ ui_auth_session_id=ui_auth_session_id or "",
),
)
@@ -1020,10 +1021,9 @@ class OidcSessionTokenGenerator:
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (session_data.client_redirect_url,)
)
- if session_data.ui_auth_session_id:
- macaroon.add_first_party_caveat(
- "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
- )
+ macaroon.add_first_party_caveat(
+ "ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
+ )
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -1046,7 +1046,7 @@ class OidcSessionTokenGenerator:
The data extracted from the session cookie
Raises:
- ValueError if an expected caveat is missing from the macaroon.
+ KeyError if an expected caveat is missing from the macaroon.
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
@@ -1057,26 +1057,16 @@ class OidcSessionTokenGenerator:
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("idp_id = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
- # Sometimes there's a UI auth session ID, it seems to be OK to attempt
- # to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
- v.satisfy_general(self._verify_expiry)
+ satisfy_expiry(v, self._clock.time_msec)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the session data from the token.
- nonce = self._get_value_from_macaroon(macaroon, "nonce")
- idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
- client_redirect_url = self._get_value_from_macaroon(
- macaroon, "client_redirect_url"
- )
- try:
- ui_auth_session_id = self._get_value_from_macaroon(
- macaroon, "ui_auth_session_id"
- ) # type: Optional[str]
- except ValueError:
- ui_auth_session_id = None
-
+ nonce = get_value_from_macaroon(macaroon, "nonce")
+ idp_id = get_value_from_macaroon(macaroon, "idp_id")
+ client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
+ ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
return OidcSessionData(
nonce=nonce,
idp_id=idp_id,
@@ -1084,33 +1074,6 @@ class OidcSessionTokenGenerator:
ui_auth_session_id=ui_auth_session_id,
)
- def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
- """Extracts a caveat value from a macaroon token.
-
- Args:
- macaroon: the token
- key: the key of the caveat to extract
-
- Returns:
- The extracted value
-
- Raises:
- ValueError: if the caveat was not in the macaroon
- """
- prefix = key + " = "
- for caveat in macaroon.caveats:
- if caveat.caveat_id.startswith(prefix):
- return caveat.caveat_id[len(prefix) :]
- raise ValueError("No %s caveat in macaroon" % (key,))
-
- def _verify_expiry(self, caveat: str) -> bool:
- prefix = "time < "
- if not caveat.startswith(prefix):
- return False
- expiry = int(caveat[len(prefix) :])
- now = self._clock.time_msec()
- return now < expiry
-
@attr.s(frozen=True, slots=True)
class OidcSessionData:
@@ -1125,8 +1088,8 @@ class OidcSessionData:
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
client_redirect_url = attr.ib(type=str)
- # The session ID of the ongoing UI Auth (None if this is a login)
- ui_auth_session_id = attr.ib(type=Optional[str], default=None)
+ # The session ID of the ongoing UI Auth ("" if this is a login)
+ ui_auth_session_id = attr.ib(type=str)
UserAttributeDict = TypedDict(
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 553fcb5b66..798c29748f 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -18,6 +18,8 @@
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from prometheus_client import Counter
+
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
@@ -41,6 +43,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+registration_counter = Counter(
+ "synapse_user_registrations_total",
+ "Number of new users registered (since restart)",
+ ["guest", "shadow_banned", "auth_provider"],
+)
+
+login_counter = Counter(
+ "synapse_user_logins_total",
+ "Number of user logins (since restart)",
+ ["guest", "auth_provider"],
+)
+
+
class RegistrationHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@@ -171,6 +186,7 @@ class RegistrationHandler(BaseHandler):
bind_emails: Iterable[str] = [],
by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+ auth_provider_id: Optional[str] = None,
) -> str:
"""Registers a new client on the server.
@@ -196,8 +212,10 @@ class RegistrationHandler(BaseHandler):
admin api, otherwise False.
user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process.
+ auth_provider_id: The SSO IdP the user used, if any (just used for the
+ prometheus metrics).
Returns:
- The registere user_id.
+ The registered user_id.
Raises:
SynapseError if there was a problem registering.
"""
@@ -304,6 +322,12 @@ class RegistrationHandler(BaseHandler):
# if user id is taken, just generate another
fail_count += 1
+ registration_counter.labels(
+ guest=make_guest,
+ shadow_banned=shadow_banned,
+ auth_provider=(auth_provider_id or ""),
+ ).inc()
+
if not self.hs.config.user_consent_at_registration:
if not self.hs.config.auto_join_rooms_for_guests and make_guest:
logger.info(
@@ -718,6 +742,7 @@ class RegistrationHandler(BaseHandler):
initial_display_name: Optional[str],
is_guest: bool = False,
is_appservice_ghost: bool = False,
+ auth_provider_id: Optional[str] = None,
) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
@@ -728,7 +753,8 @@ class RegistrationHandler(BaseHandler):
device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account
-
+ auth_provider_id: The SSO IdP the user used, if any (just used for the
+ prometheus metrics).
Returns:
Tuple of device ID and access token
"""
@@ -767,6 +793,11 @@ class RegistrationHandler(BaseHandler):
is_appservice_ghost=is_appservice_ghost,
)
+ login_counter.labels(
+ guest=is_guest,
+ auth_provider=(auth_provider_id or ""),
+ ).inc()
+
return (registered_device_id, access_token)
async def post_registration_actions(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 2271c60afc..f3da38a71e 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -121,7 +121,7 @@ class RoomCreationHandler(BaseHandler):
# succession, only process the first attempt and return its result to
# subsequent requests
self._upgrade_response_cache = ResponseCache(
- hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
+ hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
) # type: ResponseCache[Tuple[str, str]]
self._server_notices_mxid = hs.config.server_notices_mxid
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 373b9dcd0d..6be7b97019 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -44,10 +44,10 @@ class RoomListHandler(BaseHandler):
super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(
- hs, "room_list"
+ hs.get_clock(), "room_list"
) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
self.remote_response_cache = ResponseCache(
- hs, "remote_room_list", timeout_ms=30 * 1000
+ hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
async def get_local_public_room_list(
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 80e28bdcbe..6ef459acff 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -456,6 +456,7 @@ class SsoHandler:
await self._auth_handler.complete_sso_login(
user_id,
+ auth_provider_id,
request,
client_redirect_url,
extra_login_attributes,
@@ -605,6 +606,7 @@ class SsoHandler:
default_display_name=attributes.display_name,
bind_emails=attributes.emails,
user_agent_ips=[(user_agent, ip_address)],
+ auth_provider_id=auth_provider_id,
)
await self._store.record_user_external_id(
@@ -886,6 +888,7 @@ class SsoHandler:
await self._auth_handler.complete_sso_login(
user_id,
+ session.auth_provider_id,
request,
session.client_redirect_url,
session.extra_login_attributes,
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 9059382246..603349bd2a 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -258,7 +258,7 @@ class SyncHandler:
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.response_cache = ResponseCache(
- hs, "sync"
+ hs.get_clock(), "sync"
) # type: ResponseCache[Tuple[Any, ...]]
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 72901e3f95..af34d583ad 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -63,6 +63,7 @@ from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_u
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.types import ISynapseReactor
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
@@ -199,7 +200,7 @@ class _IPBlacklistingResolver:
return r
-@implementer(IReactorPluggableNameResolver)
+@implementer(ISynapseReactor)
class BlacklistingReactorWrapper:
"""
A Reactor wrapper which will prevent DNS resolution to blacklisted IP
@@ -324,7 +325,7 @@ class SimpleHttpClient:
# filters out blacklisted IP addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
- )
+ ) # type: ISynapseReactor
else:
self.reactor = hs.get_reactor()
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index b07aa59c08..5935a125fd 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -35,6 +35,7 @@ from synapse.http.client import BlacklistingAgentWrapper
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.types import ISynapseReactor
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -68,7 +69,7 @@ class MatrixFederationAgent:
def __init__(
self,
- reactor: IReactorCore,
+ reactor: ISynapseReactor,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes,
ip_blacklist: IPSet,
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 0f107714ea..da6866addf 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -59,7 +59,7 @@ from synapse.logging.opentracing import (
start_active_span,
tags,
)
-from synapse.types import JsonDict
+from synapse.types import ISynapseReactor, JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -237,14 +237,14 @@ class MatrixFederationHttpClient:
# addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
- )
+ ) # type: ISynapseReactor
user_agent = hs.version_string
if hs.config.user_agent_suffix:
user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
user_agent = user_agent.encode("ascii")
- self.agent = MatrixFederationAgent(
+ federation_agent = MatrixFederationAgent(
self.reactor,
tls_client_options_factory,
user_agent,
@@ -254,7 +254,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
- self.agent,
+ federation_agent,
ip_blacklist=hs.config.federation_ip_range_blacklist,
)
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index db2d400b7e..781e02fbbb 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -203,11 +203,26 @@ class ModuleApi:
)
def generate_short_term_login_token(
- self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+ self,
+ user_id: str,
+ duration_in_ms: int = (2 * 60 * 1000),
+ auth_provider_id: str = "",
) -> str:
- """Generate a login token suitable for m.login.token authentication"""
+ """Generate a login token suitable for m.login.token authentication
+
+ Args:
+ user_id: gives the ID of the user that the token is for
+
+ duration_in_ms: the time that the token will be valid for
+
+ auth_provider_id: the ID of the SSO IdP that the user used to authenticate
+ to get this token, if any. This is encoded in the token so that
+ /login can report stats on number of successful logins by IdP.
+ """
return self._hs.get_macaroon_generator().generate_short_term_login_token(
- user_id, duration_in_ms
+ user_id,
+ auth_provider_id,
+ duration_in_ms,
)
@defer.inlineCallbacks
@@ -276,6 +291,7 @@ class ModuleApi:
"""
self._auth_handler._complete_sso_login(
registered_user_id,
+ "<unknown>",
request,
client_redirect_url,
)
@@ -286,6 +302,7 @@ class ModuleApi:
request: SynapseRequest,
client_redirect_url: str,
new_user: bool = False,
+ auth_provider_id: str = "<unknown>",
):
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
@@ -299,9 +316,15 @@ class ModuleApi:
redirect them directly if whitelisted).
new_user: set to true to use wording for the consent appropriate to a user
who has just registered.
+ auth_provider_id: the ID of the SSO IdP which was used to log in. This
+ is used to track counts of sucessful logins by IdP.
"""
await self._auth_handler.complete_sso_login(
- registered_user_id, request, client_redirect_url, new_user=new_user
+ registered_user_id,
+ auth_provider_id,
+ request,
+ client_redirect_url,
+ new_user=new_user,
)
@defer.inlineCallbacks
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 8a3f113e76..b7aa0c280f 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -18,7 +18,7 @@ import logging
import re
import urllib
from inspect import signature
-from typing import Dict, List, Tuple
+from typing import TYPE_CHECKING, Dict, List, Tuple
from prometheus_client import Counter, Gauge
@@ -28,6 +28,9 @@ from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
_pending_outgoing_requests = Gauge(
@@ -88,10 +91,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
CACHE = True
RETRY_ON_TIMEOUT = True
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
if self.CACHE:
self.response_cache = ResponseCache(
- hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
+ hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
) # type: ResponseCache[str]
# We reserve `instance_name` as a parameter to sending requests, so we
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 0e6155cf53..7560706b4b 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -328,6 +328,6 @@ def lazyConnection(
factory.continueTrying = reconnect
reactor = hs.get_reactor()
- reactor.connectTCP(host, port, factory, 30)
+ reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
return factory.handler
diff --git a/synapse/rest/admin/purge_room_servlet.py b/synapse/rest/admin/purge_room_servlet.py
index 8b7bb6d44e..49966ee3e0 100644
--- a/synapse/rest/admin/purge_room_servlet.py
+++ b/synapse/rest/admin/purge_room_servlet.py
@@ -12,13 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, Tuple
+
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class PurgeRoomServlet(RestServlet):
@@ -36,16 +43,12 @@ class PurgeRoomServlet(RestServlet):
PATTERNS = admin_patterns("/purge_room$")
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.pagination_handler = hs.get_pagination_handler()
- async def on_POST(self, request):
+ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
diff --git a/synapse/rest/admin/server_notice_servlet.py b/synapse/rest/admin/server_notice_servlet.py
index 375d055445..f495666f4a 100644
--- a/synapse/rest/admin/server_notice_servlet.py
+++ b/synapse/rest/admin/server_notice_servlet.py
@@ -12,17 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import TYPE_CHECKING, Optional, Tuple
+
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
+from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.rest.client.transactions import HttpTransactionCache
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
class SendServerNoticeServlet(RestServlet):
@@ -44,17 +51,13 @@ class SendServerNoticeServlet(RestServlet):
}
"""
- def __init__(self, hs):
- """
- Args:
- hs (synapse.server.HomeServer): server
- """
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)
self.snm = hs.get_server_notices_manager()
- def register(self, json_resource):
+ def register(self, json_resource: HttpServer):
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
@@ -66,7 +69,9 @@ class SendServerNoticeServlet(RestServlet):
self.__class__.__name__,
)
- async def on_POST(self, request, txn_id=None):
+ async def on_POST(
+ self, request: SynapseRequest, txn_id: Optional[str] = None
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content"))
@@ -90,7 +95,7 @@ class SendServerNoticeServlet(RestServlet):
return 200, {"event_id": event.event_id}
- def on_PUT(self, request, txn_id):
+ def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, txn_id
)
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 925edfc402..34bc1bd49b 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -219,6 +219,7 @@ class LoginRestServlet(RestServlet):
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
ratelimit: bool = True,
+ auth_provider_id: Optional[str] = None,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -234,6 +235,8 @@ class LoginRestServlet(RestServlet):
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
ratelimit: Whether to ratelimit the login request.
+ auth_provider_id: The SSO IdP the user used, if any (just used for the
+ prometheus metrics).
Returns:
result: Dictionary of account information after successful login.
@@ -256,7 +259,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
- user_id, device_id, initial_display_name
+ user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
)
result = {
@@ -283,12 +286,13 @@ class LoginRestServlet(RestServlet):
"""
token = login_submission["token"]
auth_handler = self.auth_handler
- user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
- token
- )
+ res = await auth_handler.validate_short_term_login_token(token)
return await self._complete_login(
- user_id, login_submission, self.auth_handler._sso_login_callback
+ res.user_id,
+ login_submission,
+ self.auth_handler._sso_login_callback,
+ auth_provider_id=res.auth_provider_id,
)
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
diff --git a/synapse/server.py b/synapse/server.py
index afd7cd72e7..369cc88026 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -36,7 +36,6 @@ from typing import (
cast,
)
-import twisted.internet.base
import twisted.internet.tcp
from twisted.internet import defer
from twisted.mail.smtp import sendmail
@@ -130,7 +129,7 @@ from synapse.server_notices.worker_server_notices_sender import (
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import Databases, DataStore, Storage
from synapse.streams.events import EventSources
-from synapse.types import DomainSpecificString
+from synapse.types import DomainSpecificString, ISynapseReactor
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -291,7 +290,7 @@ class HomeServer(metaclass=abc.ABCMeta):
for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
getattr(self, "get_" + i + "_handler")()
- def get_reactor(self) -> twisted.internet.base.ReactorBase:
+ def get_reactor(self) -> ISynapseReactor:
"""
Fetch the Twisted reactor in use by this HomeServer.
"""
diff --git a/synapse/types.py b/synapse/types.py
index b629976853..6a41a3665d 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -36,6 +36,14 @@ import attr
from signedjson.key import decode_verify_key_bytes
from six.moves import filter
from unpaddedbase64 import decode_base64
+from zope.interface import Interface
+
+from twisted.internet.interfaces import (
+ IReactorCore,
+ IReactorPluggableNameResolver,
+ IReactorTCP,
+ IReactorTime,
+)
from synapse.api.errors import Codes, SynapseError
from synapse.util.stringutils import parse_and_validate_server_name
@@ -68,6 +76,14 @@ MutableStateMap = MutableMapping[StateKey, T]
JsonDict = Dict[str, Any]
+# Note that this seems to require inheriting *directly* from Interface in order
+# for mypy-zope to realize it is an interface.
+class ISynapseReactor(
+ IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
+):
+ """The interfaces necessary for Synapse to function."""
+
+
class Requester(
namedtuple(
"Requester",
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 719e35b78d..f33c115844 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -76,11 +76,16 @@ class ObservableDeferred:
def callback(r):
object.__setattr__(self, "_result", (True, r))
while self._observers:
+ observer = self._observers.pop()
try:
- # TODO: Handle errors here.
- self._observers.pop().callback(r)
- except Exception:
- pass
+ observer.callback(r)
+ except Exception as e:
+ logger.exception(
+ "%r threw an exception on .callback(%r), ignoring...",
+ observer,
+ r,
+ exc_info=e,
+ )
return r
def errback(f):
@@ -90,11 +95,16 @@ class ObservableDeferred:
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f
+ observer = self._observers.pop()
try:
- # TODO: Handle errors here.
- self._observers.pop().errback(f)
- except Exception:
- pass
+ observer.errback(f)
+ except Exception as e:
+ logger.exception(
+ "%r threw an exception on .errback(%r), ignoring...",
+ observer,
+ f,
+ exc_info=e,
+ )
if consumeErrors:
return None
diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py
index 32228f42ee..46ea8e0964 100644
--- a/synapse/util/caches/response_cache.py
+++ b/synapse/util/caches/response_cache.py
@@ -13,17 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
+from typing import Any, Callable, Dict, Generic, Optional, TypeVar
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
+from synapse.util import Clock
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
-if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
-
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -37,11 +35,11 @@ class ResponseCache(Generic[T]):
used rather than trying to compute a new response.
"""
- def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
+ def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
# Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
- self.clock = hs.get_clock()
+ self.clock = clock
self.timeout_sec = timeout_ms / 1000.0
self._name = name
diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py
new file mode 100644
index 0000000000..12cdd53327
--- /dev/null
+++ b/synapse/util/macaroons.py
@@ -0,0 +1,89 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 Quentin Gliech
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for manipulating macaroons"""
+
+from typing import Callable, Optional
+
+import pymacaroons
+from pymacaroons.exceptions import MacaroonVerificationFailedException
+
+
+def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
+ """Extracts a caveat value from a macaroon token.
+
+ Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
+ and returns the extracted value.
+
+ Args:
+ macaroon: the token
+ key: the key of the caveat to extract
+
+ Returns:
+ The extracted value
+
+ Raises:
+ MacaroonVerificationFailedException: if there are conflicting values for the
+ caveat in the macaroon, or if the caveat was not found in the macaroon.
+ """
+ prefix = key + " = "
+ result = None # type: Optional[str]
+ for caveat in macaroon.caveats:
+ if not caveat.caveat_id.startswith(prefix):
+ continue
+
+ val = caveat.caveat_id[len(prefix) :]
+
+ if result is None:
+ # first time we found this caveat: record the value
+ result = val
+ elif val != result:
+ # on subsequent occurrences, raise if the value is different.
+ raise MacaroonVerificationFailedException(
+ "Conflicting values for caveat " + key
+ )
+
+ if result is not None:
+ return result
+
+ # If the caveat is not there, we raise a MacaroonVerificationFailedException.
+ # Note that it is insecure to generate a macaroon without all the caveats you
+ # might need (because there is nothing stopping people from adding extra caveats),
+ # so if the caveat isn't there, something odd must be going on.
+ raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
+
+
+def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
+ """Make a macaroon verifier which accepts 'time' caveats
+
+ Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
+ the given macaroon verifier.
+
+ Args:
+ v: the macaroon verifier
+ get_time_ms: a callable which will return the timestamp after which the caveat
+ should be considered expired. Normally the current time.
+ """
+
+ def verify_expiry_caveat(caveat: str):
+ time_msec = get_time_ms()
+ prefix = "time < "
+ if not caveat.startswith(prefix):
+ return False
+ expiry = int(caveat[len(prefix) :])
+ return time_msec < expiry
+
+ v.satisfy_general(verify_expiry_caveat)
|