summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8911.feature1
-rw-r--r--changelog.d/8916.misc2
-rw-r--r--changelog.d/8935.misc1
-rw-r--r--changelog.d/8937.bugfix1
-rw-r--r--changelog.d/8938.feature1
-rw-r--r--changelog.d/8943.misc1
-rw-r--r--synapse/config/logger.py2
-rw-r--r--synapse/handlers/auth.py26
-rw-r--r--synapse/handlers/saml_handler.py23
-rw-r--r--synapse/http/site.py3
-rw-r--r--synapse/logging/context.py24
-rw-r--r--synapse/metrics/background_process_metrics.py7
-rw-r--r--synapse/push/__init__.py19
-rw-r--r--synapse/push/emailpusher.py16
-rw-r--r--synapse/push/httppusher.py17
-rw-r--r--synapse/push/pusherpool.py5
-rw-r--r--synapse/replication/tcp/protocol.py3
-rw-r--r--synapse/storage/databases/main/event_push_actions.py10
-rw-r--r--tests/handlers/test_federation.py6
-rw-r--r--tests/handlers/test_oidc.py296
-rw-r--r--tests/handlers/test_password_providers.py23
-rw-r--r--tests/handlers/test_saml.py132
-rw-r--r--tests/logging/test_terse_json.py7
-rw-r--r--tests/test_federation.py2
-rw-r--r--tests/test_utils/__init__.py12
-rw-r--r--tests/test_utils/logging_setup.py2
26 files changed, 345 insertions, 297 deletions
diff --git a/changelog.d/8911.feature b/changelog.d/8911.feature
new file mode 100644

index 0000000000..d450ef4998 --- /dev/null +++ b/changelog.d/8911.feature
@@ -0,0 +1 @@ +Add support for allowing users to pick their own user ID during a single-sign-on login. diff --git a/changelog.d/8916.misc b/changelog.d/8916.misc
index c71ef480e6..bf94135fd5 100644 --- a/changelog.d/8916.misc +++ b/changelog.d/8916.misc
@@ -1 +1 @@ -Improve structured logging tests. +Various clean-ups to the structured logging and logging context code. diff --git a/changelog.d/8935.misc b/changelog.d/8935.misc new file mode 100644
index 0000000000..bf94135fd5 --- /dev/null +++ b/changelog.d/8935.misc
@@ -0,0 +1 @@ +Various clean-ups to the structured logging and logging context code. diff --git a/changelog.d/8937.bugfix b/changelog.d/8937.bugfix new file mode 100644
index 0000000000..01e1848448 --- /dev/null +++ b/changelog.d/8937.bugfix
@@ -0,0 +1 @@ +Fix bug introduced in Synapse v1.24.0 which would cause an exception on startup if both `enabled` and `localdb_enabled` were set to `False` in the `password_config` setting of the configuration file. diff --git a/changelog.d/8938.feature b/changelog.d/8938.feature new file mode 100644
index 0000000000..d450ef4998 --- /dev/null +++ b/changelog.d/8938.feature
@@ -0,0 +1 @@ +Add support for allowing users to pick their own user ID during a single-sign-on login. diff --git a/changelog.d/8943.misc b/changelog.d/8943.misc new file mode 100644
index 0000000000..4ff0b94b94 --- /dev/null +++ b/changelog.d/8943.misc
@@ -0,0 +1 @@ +Add type hints to push module. diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index d4e887a3e0..4df3f93c1c 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py
@@ -206,7 +206,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> # filter options, but care must when using e.g. MemoryHandler to buffer # writes. - log_context_filter = LoggingContextFilter(request="") + log_context_filter = LoggingContextFilter() log_metadata_filter = MetadataFilter({"server_name": config.server_name}) old_factory = logging.getLogRecordFactory() diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d58dc3cc29..f385c72526 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -198,27 +198,25 @@ class AuthHandler(BaseHandler): self._password_enabled = hs.config.password_enabled self._password_localdb_enabled = hs.config.password_localdb_enabled - # we keep this as a list despite the O(N^2) implication so that we can - # keep PASSWORD first and avoid confusing clients which pick the first - # type in the list. (NB that the spec doesn't require us to do so and - # clients which favour types that they don't understand over those that - # they do are technically broken) - # start out by assuming PASSWORD is enabled; we will remove it later if not. - login_types = [] + login_types = set() if self._password_localdb_enabled: - login_types.append(LoginType.PASSWORD) + login_types.add(LoginType.PASSWORD) for provider in self.password_providers: - if hasattr(provider, "get_supported_login_types"): - for t in provider.get_supported_login_types().keys(): - if t not in login_types: - login_types.append(t) + login_types.update(provider.get_supported_login_types().keys()) if not self._password_enabled: + login_types.discard(LoginType.PASSWORD) + + # Some clients just pick the first type in the list. In this case, we want + # them to use PASSWORD (rather than token or whatever), so we want to make sure + # that comes first, where it's present. + self._supported_login_types = [] + if LoginType.PASSWORD in login_types: + self._supported_login_types.append(LoginType.PASSWORD) login_types.remove(LoginType.PASSWORD) - - self._supported_login_types = login_types + self._supported_login_types.extend(login_types) # Ratelimiter for failed auth during UIA. Uses same ratelimit config # as per `rc_login.failed_attempts`. diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index f2ca1ddb53..6001fe3e27 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py
@@ -163,6 +163,29 @@ class SamlHandler(BaseHandler): return logger.debug("SAML2 response: %s", saml2_auth.origxml) + + await self._handle_authn_response(request, saml2_auth, relay_state) + + async def _handle_authn_response( + self, + request: SynapseRequest, + saml2_auth: saml2.response.AuthnResponse, + relay_state: str, + ) -> None: + """Handle an AuthnResponse, having parsed it from the request params + + Assumes that the signature on the response object has been checked. Maps + the user onto an MXID, registering them if necessary, and returns a response + to the browser. + + Args: + request: the incoming request from the browser. We'll respond to it with an + HTML page or a redirect + saml2_auth: the parsed AuthnResponse object + relay_state: the RelayState query param, which encodes the URI to rediret + back to + """ + for assertion in saml2_auth.assertions: # kibana limits the length of a log field, whereas this is all rather # useful, so split it up. diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5f0581dc3f..5a5790831b 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py
@@ -128,8 +128,7 @@ class SynapseRequest(Request): # create a LogContext for this request request_id = self.get_request_id() - logcontext = self.logcontext = LoggingContext(request_id) - logcontext.request = request_id + self.logcontext = LoggingContext(request_id, request=request_id) # override the Server header which is set by twisted self.setHeader("Server", self.site.server_version_string) diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index ca0c774cc5..a507a83e93 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py
@@ -203,10 +203,6 @@ class _Sentinel: def copy_to(self, record): pass - def copy_to_twisted_log_entry(self, record): - record["request"] = None - record["scope"] = None - def start(self, rusage: "Optional[resource._RUsage]"): pass @@ -372,13 +368,6 @@ class LoggingContext: # we also track the current scope: record.scope = self.scope - def copy_to_twisted_log_entry(self, record) -> None: - """ - Copy logging fields from this context to a Twisted log record. - """ - record["request"] = self.request - record["scope"] = self.scope - def start(self, rusage: "Optional[resource._RUsage]") -> None: """ Record that this logcontext is currently running. @@ -542,13 +531,10 @@ class LoggingContext: class LoggingContextFilter(logging.Filter): """Logging filter that adds values from the current logging context to each record. - Args: - **defaults: Default values to avoid formatters complaining about - missing fields """ - def __init__(self, **defaults) -> None: - self.defaults = defaults + def __init__(self, request: str = ""): + self._default_request = request def filter(self, record) -> Literal[True]: """Add each fields from the logging contexts to the record. @@ -556,14 +542,14 @@ class LoggingContextFilter(logging.Filter): True to include the record in the log output. """ context = current_context() - for key, value in self.defaults.items(): - setattr(record, key, value) + record.request = self._default_request # context should never be None, but if it somehow ends up being, then # we end up in a death spiral of infinite loops, so let's check, for # robustness' sake. if context is not None: - context.copy_to(record) + # Logging is interested in the request. + record.request = context.request return True diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 76b7decf26..70e0fa45d9 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py
@@ -199,8 +199,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar _background_process_start_count.labels(desc).inc() _background_process_in_flight_count.labels(desc).inc() - with BackgroundProcessLoggingContext(desc) as context: - context.request = "%s-%i" % (desc, count) + with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context: try: ctx = noop_context_manager() if bg_start_span: @@ -244,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext): __slots__ = ["_proc"] - def __init__(self, name: str): - super().__init__(name) + def __init__(self, name: str, request: Optional[str] = None): + super().__init__(name, request=request) self._proc = _BackgroundProcess(name, self) diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 3d2e874838..ad07ee86f6 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py
@@ -14,7 +14,7 @@ # limitations under the License. import abc -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any, Dict from synapse.types import RoomStreamToken @@ -36,12 +36,21 @@ class Pusher(metaclass=abc.ABCMeta): # This is the highest stream ordering we know it's safe to process. # When new events arrive, we'll be given a window of new events: we # should honour this rather than just looking for anything higher - # because of potential out-of-order event serialisation. This starts - # off as None though as we don't know any better. - self.max_stream_ordering = None # type: Optional[int] + # because of potential out-of-order event serialisation. + self.max_stream_ordering = self.store.get_room_max_stream_ordering() - @abc.abstractmethod def on_new_notifications(self, max_token: RoomStreamToken) -> None: + # We just use the minimum stream ordering and ignore the vector clock + # component. This is safe to do as long as we *always* ignore the vector + # clock components. + max_stream_ordering = max_token.stream + + self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering) + self._start_processing() + + @abc.abstractmethod + def _start_processing(self): + """Start processing push notifications.""" raise NotImplementedError() @abc.abstractmethod diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 64a35c1994..11a97b8df4 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py
@@ -22,7 +22,6 @@ from twisted.internet.error import AlreadyCalled, AlreadyCancelled from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher from synapse.push.mailer import Mailer -from synapse.types import RoomStreamToken if TYPE_CHECKING: from synapse.app.homeserver import HomeServer @@ -93,20 +92,6 @@ class EmailPusher(Pusher): pass self.timed_call = None - def on_new_notifications(self, max_token: RoomStreamToken) -> None: - # We just use the minimum stream ordering and ignore the vector clock - # component. This is safe to do as long as we *always* ignore the vector - # clock components. - max_stream_ordering = max_token.stream - - if self.max_stream_ordering: - self.max_stream_ordering = max( - max_stream_ordering, self.max_stream_ordering - ) - else: - self.max_stream_ordering = max_stream_ordering - self._start_processing() - def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: # We could wake up and cancel the timer but there tend to be quite a # lot of read receipts so it's probably less work to just let the @@ -172,7 +157,6 @@ class EmailPusher(Pusher): being run. """ start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering - assert self.max_stream_ordering is not None unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email( self.user_id, start, self.max_stream_ordering ) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 5408aa1295..e8b25bcd2a 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py
@@ -26,7 +26,6 @@ from synapse.events import EventBase from synapse.logging import opentracing from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfigException -from synapse.types import RoomStreamToken from . import push_rule_evaluator, push_tools @@ -122,17 +121,6 @@ class HttpPusher(Pusher): if should_check_for_notifs: self._start_processing() - def on_new_notifications(self, max_token: RoomStreamToken) -> None: - # We just use the minimum stream ordering and ignore the vector clock - # component. This is safe to do as long as we *always* ignore the vector - # clock components. - max_stream_ordering = max_token.stream - - self.max_stream_ordering = max( - max_stream_ordering, self.max_stream_ordering or 0 - ) - self._start_processing() - def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None: # Note that the min here shouldn't be relied upon to be accurate. @@ -192,10 +180,7 @@ class HttpPusher(Pusher): Never call this directly: use _process which will only allow this to run once per pusher. """ - - fn = self.store.get_unread_push_actions_for_user_in_range_for_http - assert self.max_stream_ordering is not None - unprocessed = await fn( + unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http( self.user_id, self.last_stream_ordering, self.max_stream_ordering ) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 748b1407c6..409e46da8f 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py
@@ -129,9 +129,8 @@ class PusherPool: ) # create the pusher setting last_stream_ordering to the current maximum - # stream ordering in event_push_actions, so it will process - # pushes from this point onwards. - last_stream_ordering = await self.store.get_latest_push_action_stream_ordering() + # stream ordering, so it will process pushes from this point onwards. + last_stream_ordering = self.store.get_room_max_stream_ordering() await self.store.add_pusher( user_id=user_id, diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index a509e599c2..804da994ea 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py
@@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. ctx_name = "replication-conn-%s" % self.conn_id - self._logging_context = BackgroundProcessLoggingContext(ctx_name) - self._logging_context.request = ctx_name + self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name) def connectionMade(self): logger.info("[%s] Connection established", self.id()) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 2e56dfaf31..e5c03cc609 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py
@@ -894,16 +894,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore): pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions - async def get_latest_push_action_stream_ordering(self): - def f(txn): - txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions") - return txn.fetchone() - - result = await self.db_pool.runInteraction( - "get_latest_push_action_stream_ordering", f - ) - return result[0] or 0 - def _remove_old_push_actions_before_txn( self, txn, room_id, user_id, stream_ordering ): diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index d0452e1490..0b24b89a2e 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py
@@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room_version, ) - with LoggingContext(request="send_rejected"): + with LoggingContext("send_rejected"): d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) self.get_success(d) @@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room_version, ) - with LoggingContext(request="send_rejected"): + with LoggingContext("send_rejected"): d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) self.get_success(d) @@ -198,7 +198,7 @@ class FederationTestCase(unittest.HomeserverTestCase): # the auth code requires that a signature exists, but doesn't check that # signature... go figure. join_event.signatures[other_server] = {"x": "y"} - with LoggingContext(request="send_join"): + with LoggingContext("send_join"): d = run_in_background( self.handler.on_send_join_request, other_server, join_event ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1d99a45436..464e569ac8 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py
@@ -15,7 +15,7 @@ import json from urllib.parse import parse_qs, urlparse -from mock import Mock, patch +from mock import ANY, Mock, patch import pymacaroons @@ -23,7 +23,7 @@ from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider from synapse.handlers.sso import MappingException from synapse.types import UserID -from tests.test_utils import FakeResponse +from tests.test_utils import FakeResponse, simple_async_mock from tests.unittest import HomeserverTestCase, override_config # These are a few constants that are used as config parameters in the tests. @@ -82,16 +82,6 @@ class TestMappingProviderFailures(TestMappingProvider): } -def simple_async_mock(return_value=None, raises=None): - # AsyncMock is not available in python3.5, this mimics part of its behaviour - async def cb(*args, **kwargs): - if raises: - raise raises - return return_value - - return Mock(side_effect=cb) - - async def get_json(url): # Mock get_json calls to handle jwks & oidc discovery endpoints if url == WELL_KNOWN: @@ -160,6 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(args[2], error_description) # Reset the render_error mock self.render_error.reset_mock() + return args def test_config(self): """Basic config correctly sets up the callback URL and client auth correctly.""" @@ -374,26 +365,17 @@ class OidcHandlerTestCase(HomeserverTestCase): "id_token": "id_token", "access_token": "access_token", } + username = "bar" userinfo = { "sub": "foo", - "preferred_username": "bar", + "username": username, } - user_id = "@foo:domain.org" + expected_user_id = "@%s:%s" % (username, self.hs.hostname) self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._parse_id_token = simple_async_mock(return_value=userinfo) self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) - self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) - self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock( - spec=[ - "args", - "getCookie", - "addCookie", - "requestHeaders", - "getClientIP", - "get_user_agent", - ] - ) + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() code = "code" state = "state" @@ -401,64 +383,54 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url = "http://client/redirect" user_agent = "Browser" ip_address = "10.0.0.1" - request.getCookie.return_value = self.handler._generate_oidc_session_token( + session = self.handler._generate_oidc_session_token( state=state, nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=None, ) - - request.args = {} - request.args[b"code"] = [code.encode("utf-8")] - request.args[b"state"] = [state.encode("utf-8")] - - request.getClientIP.return_value = ip_address - request.get_user_agent.return_value = user_agent + request = self._build_callback_request( + code, state, session, user_agent=user_agent, ip_address=ip_address + ) self.get_success(self.handler.handle_oidc_callback(request)) - self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, {}, + auth_handler.complete_sso_login.assert_called_once_with( + expected_user_id, request, client_redirect_url, {}, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._map_userinfo_to_user.assert_called_once_with( - userinfo, token, user_agent, ip_address - ) self.handler._fetch_userinfo.assert_not_called() self.render_error.assert_not_called() # Handle mapping errors - self.handler._map_userinfo_to_user = simple_async_mock( - raises=MappingException() - ) - self.get_success(self.handler.handle_oidc_callback(request)) - self.assertRenderedError("mapping_error") - self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + with patch.object( + self.handler, + "_remote_id_from_userinfo", + new=Mock(side_effect=MappingException()), + ): + self.get_success(self.handler.handle_oidc_callback(request)) + self.assertRenderedError("mapping_error") # Handle ID token errors self.handler._parse_id_token = simple_async_mock(raises=Exception()) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") - self.handler._auth_handler.complete_sso_login.reset_mock() + auth_handler.complete_sso_login.reset_mock() self.handler._exchange_code.reset_mock() self.handler._parse_id_token.reset_mock() - self.handler._map_userinfo_to_user.reset_mock() self.handler._fetch_userinfo.reset_mock() # With userinfo fetching self.handler._scopes = [] # do not ask the "openid" scope self.get_success(self.handler.handle_oidc_callback(request)) - self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, {}, + auth_handler.complete_sso_login.assert_called_once_with( + expected_user_id, request, client_redirect_url, {}, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() - self.handler._map_userinfo_to_user.assert_called_once_with( - userinfo, token, user_agent, ip_address - ) self.handler._fetch_userinfo.assert_called_once_with(token) self.render_error.assert_not_called() @@ -609,72 +581,55 @@ class OidcHandlerTestCase(HomeserverTestCase): } userinfo = { "sub": "foo", + "username": "foo", "phone": "1234567", } - user_id = "@foo:domain.org" self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._parse_id_token = simple_async_mock(return_value=userinfo) - self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) - self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock( - spec=[ - "args", - "getCookie", - "addCookie", - "requestHeaders", - "getClientIP", - "get_user_agent", - ] - ) + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() state = "state" client_redirect_url = "http://client/redirect" - request.getCookie.return_value = self.handler._generate_oidc_session_token( + session = self.handler._generate_oidc_session_token( state=state, nonce="nonce", client_redirect_url=client_redirect_url, ui_auth_session_id=None, ) - - request.args = {} - request.args[b"code"] = [b"code"] - request.args[b"state"] = [state.encode("utf-8")] - - request.getClientIP.return_value = "10.0.0.1" - request.get_user_agent.return_value = "Browser" + request = self._build_callback_request("code", state, session) self.get_success(self.handler.handle_oidc_callback(request)) - self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, {"phone": "1234567"}, + auth_handler.complete_sso_login.assert_called_once_with( + "@foo:test", request, client_redirect_url, {"phone": "1234567"}, ) def test_map_userinfo_to_user(self): """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + userinfo = { "sub": "test_user", "username": "test_user", } - # The token doesn't matter with the default user mapping provider. - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", ANY, ANY, {} ) - self.assertEqual(mxid, "@test_user:test") + auth_handler.complete_sso_login.reset_mock() # Some providers return an integer ID. userinfo = { "sub": 1234, "username": "test_user_2", } - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user_2:test", ANY, ANY, {} ) - self.assertEqual(mxid, "@test_user_2:test") + auth_handler.complete_sso_login.reset_mock() # Test if the mxid is already taken store = self.hs.get_datastore() @@ -683,14 +638,11 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user3.to_string(), password_hash=None) ) userinfo = {"sub": "test3", "username": "test_user_3"} - e = self.get_failure( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ), - MappingException, - ) - self.assertEqual( - str(e.value), "Mapping provider does not support de-duplicating Matrix IDs", + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_not_called() + self.assertRenderedError( + "mapping_error", + "Mapping provider does not support de-duplicating Matrix IDs", ) @override_config({"oidc_config": {"allow_existing_users": True}}) @@ -702,26 +654,26 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user.to_string(), password_hash=None) ) + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # Map a user via SSO. userinfo = { "sub": "test", "username": "test_user", } - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_called_once_with( + user.to_string(), ANY, ANY, {}, ) - self.assertEqual(mxid, "@test_user:test") + auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_called_once_with( + user.to_string(), ANY, ANY, {}, ) - self.assertEqual(mxid, "@test_user:test") + auth_handler.complete_sso_login.reset_mock() # Note that a second SSO user can be mapped to the same Matrix ID. (This # requires a unique sub, but something that maps to the same matrix ID, @@ -732,13 +684,11 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test1", "username": "test_user", } - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_called_once_with( + user.to_string(), ANY, ANY, {}, ) - self.assertEqual(mxid, "@test_user:test") + auth_handler.complete_sso_login.reset_mock() # Register some non-exact matching cases. user2 = UserID.from_string("@TEST_user_2:test") @@ -755,14 +705,11 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test2", "username": "TEST_USER_2", } - e = self.get_failure( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ), - MappingException, - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_not_called() + args = self.assertRenderedError("mapping_error") self.assertTrue( - str(e.value).startswith( + args[2].startswith( "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:" ) ) @@ -773,28 +720,15 @@ class OidcHandlerTestCase(HomeserverTestCase): store.register_user(user_id=user2.to_string(), password_hash=None) ) - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_called_once_with( + "@TEST_USER_2:test", ANY, ANY, {}, ) - self.assertEqual(mxid, "@TEST_USER_2:test") def test_map_userinfo_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" - userinfo = { - "sub": "test2", - "username": "föö", - } - token = {} - - e = self.get_failure( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ), - MappingException, - ) - self.assertEqual(str(e.value), "localpart is invalid: föö") + self._make_callback_with_userinfo({"sub": "test2", "username": "föö"}) + self.assertRenderedError("mapping_error", "localpart is invalid: föö") @override_config( { @@ -807,6 +741,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) def test_map_userinfo_to_user_retries(self): """The mapping provider can retry generating an MXID if the MXID is already in use.""" + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + store = self.hs.get_datastore() self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) @@ -815,14 +752,13 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "test", "username": "test_user", } - token = {} - mxid = self.get_success( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ) - ) + self._make_callback_with_userinfo(userinfo) + # test_user is already taken, so test_user1 gets registered instead. - self.assertEqual(mxid, "@test_user1:test") + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user1:test", ANY, ANY, {}, + ) + auth_handler.complete_sso_login.reset_mock() # Register all of the potential mxids for a particular OIDC username. self.get_success( @@ -838,12 +774,70 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "tester", "username": "tester", } - e = self.get_failure( - self.handler._map_userinfo_to_user( - userinfo, token, "user-agent", "10.10.10.10" - ), - MappingException, + self._make_callback_with_userinfo(userinfo) + auth_handler.complete_sso_login.assert_not_called() + self.assertRenderedError( + "mapping_error", "Unable to generate a Matrix ID from the SSO response" + ) + + def _make_callback_with_userinfo( + self, userinfo: dict, client_redirect_url: str = "http://client/redirect" + ) -> None: + self.handler._exchange_code = simple_async_mock(return_value={}) + self.handler._parse_id_token = simple_async_mock(return_value=userinfo) + self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + state = "state" + session = self.handler._generate_oidc_session_token( + state=state, + nonce="nonce", + client_redirect_url=client_redirect_url, + ui_auth_session_id=None, ) - self.assertEqual( - str(e.value), "Unable to generate a Matrix ID from the SSO response" + request = self._build_callback_request("code", state, session) + + self.get_success(self.handler.handle_oidc_callback(request)) + + def _build_callback_request( + self, + code: str, + state: str, + session: str, + user_agent: str = "Browser", + ip_address: str = "10.0.0.1", + ): + """Builds a fake SynapseRequest to mock the browser callback + + Returns a Mock object which looks like the SynapseRequest we get from a browser + after SSO (before we return to the client) + + Args: + code: the authorization code which would have been returned by the OIDC + provider + state: the "state" param which would have been passed around in the + query param. Should be the same as was embedded in the session in + _build_oidc_session. + session: the "session" which would have been passed around in the cookie. + user_agent: the user-agent to present + ip_address: the IP address to pretend the request came from + """ + request = Mock( + spec=[ + "args", + "getCookie", + "addCookie", + "requestHeaders", + "getClientIP", + "get_user_agent", + ] ) + + request.getCookie.return_value = session + request.args = {} + request.args[b"code"] = [code.encode("utf-8")] + request.args[b"state"] = [state.encode("utf-8")] + request.getClientIP.return_value = ip_address + request.get_user_agent.return_value = user_agent + return request diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ceaf0902d2..8d50265145 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py
@@ -432,6 +432,29 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { + **providers_config(CustomAuthProvider), + "password_config": {"enabled": False, "localdb_enabled": False}, + } + ) + def test_custom_auth_password_disabled_localdb_enabled(self): + """Check the localdb_enabled == enabled == False + + Regression test for https://github.com/matrix-org/synapse/issues/8914: check + that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't + cause an exception. + """ + self.register_user("localuser", "localpass") + + flows = self._get_login_flows() + self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS) + + # login shouldn't work and should be rejected with a 400 ("unknown login type") + channel = self._send_password_login("localuser", "localpass") + self.assertEqual(channel.code, 400, channel.result) + mock_password_provider.check_auth.assert_not_called() + + @override_config( + { **providers_config(PasswordCustomAuthProvider), "password_config": {"enabled": False}, } diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index d21e5588ca..69927cf6be 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py
@@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional + +from mock import Mock + import attr from synapse.api.errors import RedirectException -from synapse.handlers.sso import MappingException +from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config # Check if we have the dependencies to run the tests. @@ -44,6 +48,8 @@ BASE_URL = "https://synapse/" @attr.s class FakeAuthnResponse: ava = attr.ib(type=dict) + assertions = attr.ib(type=list, factory=list) + in_response_to = attr.ib(type=Optional[str], default=None) class TestMappingProvider: @@ -111,15 +117,22 @@ class SamlHandlerTestCase(HomeserverTestCase): def test_map_saml_response_to_user(self): """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # send a mocked-up SAML response to the callback saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) - # The redirect_url doesn't matter with the default user mapping provider. - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "redirect_uri") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "redirect_uri" ) - self.assertEqual(mxid, "@test_user:test") @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) def test_map_saml_response_to_existing_user(self): @@ -129,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase): store.register_user(user_id="@test_user:test", password_hash=None) ) + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + # Map a user via SSO. saml_response = FakeAuthnResponse( {"uid": "tester", "mxid": ["test_user"], "username": "test_user"} ) - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + + # check that the auth handler got called as expected + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "" ) - self.assertEqual(mxid, "@test_user:test") # Subsequent calls should map to the same mxid. - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + auth_handler.complete_sso_login.reset_mock() + self.get_success( + self.handler._handle_authn_response(request, saml_response, "") + ) + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user:test", request, "" ) - self.assertEqual(mxid, "@test_user:test") def test_map_saml_response_to_invalid_localpart(self): """If the mapping provider generates an invalid localpart it should be rejected.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + + # mock out the error renderer too + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) - redirect_url = "" - e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), - MappingException, + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), + ) + sso_handler.render_error.assert_called_once_with( + request, "mapping_error", "localpart is invalid: föö" ) - self.assertEqual(str(e.value), "localpart is invalid: föö") + auth_handler.complete_sso_login.assert_not_called() def test_map_saml_response_to_user_retries(self): """The mapping provider can retry generating an MXID if the MXID is already in use.""" + + # stub out the auth handler and error renderer + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = simple_async_mock() + sso_handler = self.hs.get_sso_handler() + sso_handler.render_error = Mock(return_value=None) + + # register a user to occupy the first-choice MXID store = self.hs.get_datastore() self.get_success( store.register_user(user_id="@test_user:test", password_hash=None) ) + + # send the fake SAML response saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - redirect_url = "" - mxid = self.get_success( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ) + request = _mock_request() + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), ) + # test_user is already taken, so test_user1 gets registered instead. - self.assertEqual(mxid, "@test_user1:test") + auth_handler.complete_sso_login.assert_called_once_with( + "@test_user1:test", request, "" + ) + auth_handler.complete_sso_login.reset_mock() # Register all of the potential mxids for a particular SAML username. self.get_success( @@ -188,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase): # Now attempt to map to a username, this will fail since all potential usernames are taken. saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"}) - e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), - MappingException, + self.get_success( + self.handler._handle_authn_response(request, saml_response, ""), ) - self.assertEqual( - str(e.value), "Unable to generate a Matrix ID from the SSO response" + sso_handler.render_error.assert_called_once_with( + request, + "mapping_error", + "Unable to generate a Matrix ID from the SSO response", ) + auth_handler.complete_sso_login.assert_not_called() @override_config( { @@ -208,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase): } ) def test_map_saml_response_redirect(self): + """Test a mapping provider that raises a RedirectException""" + saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) - redirect_url = "" + request = _mock_request() e = self.get_failure( - self.handler._map_saml_response_to_user( - saml_response, redirect_url, "user-agent", "10.10.10.10" - ), + self.handler._handle_authn_response(request, saml_response, ""), RedirectException, ) self.assertEqual(e.value.location, b"https://custom-saml-redirect/") + + +def _mock_request(): + """Returns a mock which will stand in as a SynapseRequest""" + return Mock(spec=["getClientIP", "get_user_agent"]) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index f6e7e5fdaa..48a74e2eee 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py
@@ -117,11 +117,10 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): """ handler = logging.StreamHandler(self.output) handler.setFormatter(JsonFormatter()) - handler.addFilter(LoggingContextFilter(request="")) + handler.addFilter(LoggingContextFilter()) logger = self.get_logger(handler) - with LoggingContext() as context_one: - context_one.request = "test" + with LoggingContext(request="test"): logger.info("Hello there, %s!", "wally") log = self.get_log_line() @@ -132,9 +131,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): "level", "namespace", "request", - "scope", ] self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") self.assertEqual(log["request"], "test") - self.assertIsNone(log["scope"]) diff --git a/tests/test_federation.py b/tests/test_federation.py
index fa45f8b3b7..fc9aab32d0 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py
@@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) - with LoggingContext(request="lying_event"): + with LoggingContext(): failure = self.get_failure( self.handler.on_receive_pdu( "test.serv", lying_event, sent_to_us_directly=True diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 6873d45eb6..43898d8142 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py
@@ -22,6 +22,8 @@ import warnings from asyncio import Future from typing import Any, Awaitable, Callable, TypeVar +from mock import Mock + import attr from twisted.python.failure import Failure @@ -87,6 +89,16 @@ def setup_awaitable_errors() -> Callable[[], None]: return cleanup +def simple_async_mock(return_value=None, raises=None) -> Mock: + # AsyncMock is not available in python3.5, this mimics part of its behaviour + async def cb(*args, **kwargs): + if raises: + raise raises + return return_value + + return Mock(side_effect=cb) + + @attr.s class FakeResponse: """A fake twisted.web.IResponse object diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index fdfb840b62..52ae5c5713 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py
@@ -48,7 +48,7 @@ def setup_logging(): handler = ToTwistedHandler() formatter = logging.Formatter(log_format) handler.setFormatter(formatter) - handler.addFilter(LoggingContextFilter(request="")) + handler.addFilter(LoggingContextFilter()) root_logger.addHandler(handler) log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")