diff --git a/changelog.d/8911.feature b/changelog.d/8911.feature
new file mode 100644
index 0000000000..d450ef4998
--- /dev/null
+++ b/changelog.d/8911.feature
@@ -0,0 +1 @@
+Add support for allowing users to pick their own user ID during a single-sign-on login.
diff --git a/changelog.d/8916.misc b/changelog.d/8916.misc
index c71ef480e6..bf94135fd5 100644
--- a/changelog.d/8916.misc
+++ b/changelog.d/8916.misc
@@ -1 +1 @@
-Improve structured logging tests.
+Various clean-ups to the structured logging and logging context code.
diff --git a/changelog.d/8935.misc b/changelog.d/8935.misc
new file mode 100644
index 0000000000..bf94135fd5
--- /dev/null
+++ b/changelog.d/8935.misc
@@ -0,0 +1 @@
+Various clean-ups to the structured logging and logging context code.
diff --git a/changelog.d/8937.bugfix b/changelog.d/8937.bugfix
new file mode 100644
index 0000000000..01e1848448
--- /dev/null
+++ b/changelog.d/8937.bugfix
@@ -0,0 +1 @@
+Fix bug introduced in Synapse v1.24.0 which would cause an exception on startup if both `enabled` and `localdb_enabled` were set to `False` in the `password_config` setting of the configuration file.
diff --git a/changelog.d/8938.feature b/changelog.d/8938.feature
new file mode 100644
index 0000000000..d450ef4998
--- /dev/null
+++ b/changelog.d/8938.feature
@@ -0,0 +1 @@
+Add support for allowing users to pick their own user ID during a single-sign-on login.
diff --git a/changelog.d/8943.misc b/changelog.d/8943.misc
new file mode 100644
index 0000000000..4ff0b94b94
--- /dev/null
+++ b/changelog.d/8943.misc
@@ -0,0 +1 @@
+Add type hints to push module.
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index d4e887a3e0..4df3f93c1c 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -206,7 +206,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
# filter options, but care must when using e.g. MemoryHandler to buffer
# writes.
- log_context_filter = LoggingContextFilter(request="")
+ log_context_filter = LoggingContextFilter()
log_metadata_filter = MetadataFilter({"server_name": config.server_name})
old_factory = logging.getLogRecordFactory()
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index d58dc3cc29..f385c72526 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -198,27 +198,25 @@ class AuthHandler(BaseHandler):
self._password_enabled = hs.config.password_enabled
self._password_localdb_enabled = hs.config.password_localdb_enabled
- # we keep this as a list despite the O(N^2) implication so that we can
- # keep PASSWORD first and avoid confusing clients which pick the first
- # type in the list. (NB that the spec doesn't require us to do so and
- # clients which favour types that they don't understand over those that
- # they do are technically broken)
-
# start out by assuming PASSWORD is enabled; we will remove it later if not.
- login_types = []
+ login_types = set()
if self._password_localdb_enabled:
- login_types.append(LoginType.PASSWORD)
+ login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
- if hasattr(provider, "get_supported_login_types"):
- for t in provider.get_supported_login_types().keys():
- if t not in login_types:
- login_types.append(t)
+ login_types.update(provider.get_supported_login_types().keys())
if not self._password_enabled:
+ login_types.discard(LoginType.PASSWORD)
+
+ # Some clients just pick the first type in the list. In this case, we want
+ # them to use PASSWORD (rather than token or whatever), so we want to make sure
+ # that comes first, where it's present.
+ self._supported_login_types = []
+ if LoginType.PASSWORD in login_types:
+ self._supported_login_types.append(LoginType.PASSWORD)
login_types.remove(LoginType.PASSWORD)
-
- self._supported_login_types = login_types
+ self._supported_login_types.extend(login_types)
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index f2ca1ddb53..6001fe3e27 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -163,6 +163,29 @@ class SamlHandler(BaseHandler):
return
logger.debug("SAML2 response: %s", saml2_auth.origxml)
+
+ await self._handle_authn_response(request, saml2_auth, relay_state)
+
+ async def _handle_authn_response(
+ self,
+ request: SynapseRequest,
+ saml2_auth: saml2.response.AuthnResponse,
+ relay_state: str,
+ ) -> None:
+ """Handle an AuthnResponse, having parsed it from the request params
+
+ Assumes that the signature on the response object has been checked. Maps
+ the user onto an MXID, registering them if necessary, and returns a response
+ to the browser.
+
+ Args:
+ request: the incoming request from the browser. We'll respond to it with an
+ HTML page or a redirect
+ saml2_auth: the parsed AuthnResponse object
+ relay_state: the RelayState query param, which encodes the URI to rediret
+ back to
+ """
+
for assertion in saml2_auth.assertions:
# kibana limits the length of a log field, whereas this is all rather
# useful, so split it up.
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5f0581dc3f..5a5790831b 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -128,8 +128,7 @@ class SynapseRequest(Request):
# create a LogContext for this request
request_id = self.get_request_id()
- logcontext = self.logcontext = LoggingContext(request_id)
- logcontext.request = request_id
+ self.logcontext = LoggingContext(request_id, request=request_id)
# override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string)
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index ca0c774cc5..a507a83e93 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -203,10 +203,6 @@ class _Sentinel:
def copy_to(self, record):
pass
- def copy_to_twisted_log_entry(self, record):
- record["request"] = None
- record["scope"] = None
-
def start(self, rusage: "Optional[resource._RUsage]"):
pass
@@ -372,13 +368,6 @@ class LoggingContext:
# we also track the current scope:
record.scope = self.scope
- def copy_to_twisted_log_entry(self, record) -> None:
- """
- Copy logging fields from this context to a Twisted log record.
- """
- record["request"] = self.request
- record["scope"] = self.scope
-
def start(self, rusage: "Optional[resource._RUsage]") -> None:
"""
Record that this logcontext is currently running.
@@ -542,13 +531,10 @@ class LoggingContext:
class LoggingContextFilter(logging.Filter):
"""Logging filter that adds values from the current logging context to each
record.
- Args:
- **defaults: Default values to avoid formatters complaining about
- missing fields
"""
- def __init__(self, **defaults) -> None:
- self.defaults = defaults
+ def __init__(self, request: str = ""):
+ self._default_request = request
def filter(self, record) -> Literal[True]:
"""Add each fields from the logging contexts to the record.
@@ -556,14 +542,14 @@ class LoggingContextFilter(logging.Filter):
True to include the record in the log output.
"""
context = current_context()
- for key, value in self.defaults.items():
- setattr(record, key, value)
+ record.request = self._default_request
# context should never be None, but if it somehow ends up being, then
# we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake.
if context is not None:
- context.copy_to(record)
+ # Logging is interested in the request.
+ record.request = context.request
return True
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 76b7decf26..70e0fa45d9 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -199,8 +199,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
_background_process_start_count.labels(desc).inc()
_background_process_in_flight_count.labels(desc).inc()
- with BackgroundProcessLoggingContext(desc) as context:
- context.request = "%s-%i" % (desc, count)
+ with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
try:
ctx = noop_context_manager()
if bg_start_span:
@@ -244,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext):
__slots__ = ["_proc"]
- def __init__(self, name: str):
- super().__init__(name)
+ def __init__(self, name: str, request: Optional[str] = None):
+ super().__init__(name, request=request)
self._proc = _BackgroundProcess(name, self)
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 3d2e874838..ad07ee86f6 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -14,7 +14,7 @@
# limitations under the License.
import abc
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any, Dict
from synapse.types import RoomStreamToken
@@ -36,12 +36,21 @@ class Pusher(metaclass=abc.ABCMeta):
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
# should honour this rather than just looking for anything higher
- # because of potential out-of-order event serialisation. This starts
- # off as None though as we don't know any better.
- self.max_stream_ordering = None # type: Optional[int]
+ # because of potential out-of-order event serialisation.
+ self.max_stream_ordering = self.store.get_room_max_stream_ordering()
- @abc.abstractmethod
def on_new_notifications(self, max_token: RoomStreamToken) -> None:
+ # We just use the minimum stream ordering and ignore the vector clock
+ # component. This is safe to do as long as we *always* ignore the vector
+ # clock components.
+ max_stream_ordering = max_token.stream
+
+ self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
+ self._start_processing()
+
+ @abc.abstractmethod
+ def _start_processing(self):
+ """Start processing push notifications."""
raise NotImplementedError()
@abc.abstractmethod
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index 64a35c1994..11a97b8df4 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -22,7 +22,6 @@ from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher
from synapse.push.mailer import Mailer
-from synapse.types import RoomStreamToken
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
@@ -93,20 +92,6 @@ class EmailPusher(Pusher):
pass
self.timed_call = None
- def on_new_notifications(self, max_token: RoomStreamToken) -> None:
- # We just use the minimum stream ordering and ignore the vector clock
- # component. This is safe to do as long as we *always* ignore the vector
- # clock components.
- max_stream_ordering = max_token.stream
-
- if self.max_stream_ordering:
- self.max_stream_ordering = max(
- max_stream_ordering, self.max_stream_ordering
- )
- else:
- self.max_stream_ordering = max_stream_ordering
- self._start_processing()
-
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
@@ -172,7 +157,6 @@ class EmailPusher(Pusher):
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
- assert self.max_stream_ordering is not None
unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
self.user_id, start, self.max_stream_ordering
)
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 5408aa1295..e8b25bcd2a 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -26,7 +26,6 @@ from synapse.events import EventBase
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import Pusher, PusherConfigException
-from synapse.types import RoomStreamToken
from . import push_rule_evaluator, push_tools
@@ -122,17 +121,6 @@ class HttpPusher(Pusher):
if should_check_for_notifs:
self._start_processing()
- def on_new_notifications(self, max_token: RoomStreamToken) -> None:
- # We just use the minimum stream ordering and ignore the vector clock
- # component. This is safe to do as long as we *always* ignore the vector
- # clock components.
- max_stream_ordering = max_token.stream
-
- self.max_stream_ordering = max(
- max_stream_ordering, self.max_stream_ordering or 0
- )
- self._start_processing()
-
def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# Note that the min here shouldn't be relied upon to be accurate.
@@ -192,10 +180,7 @@ class HttpPusher(Pusher):
Never call this directly: use _process which will only allow this to
run once per pusher.
"""
-
- fn = self.store.get_unread_push_actions_for_user_in_range_for_http
- assert self.max_stream_ordering is not None
- unprocessed = await fn(
+ unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index 748b1407c6..409e46da8f 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -129,9 +129,8 @@ class PusherPool:
)
# create the pusher setting last_stream_ordering to the current maximum
- # stream ordering in event_push_actions, so it will process
- # pushes from this point onwards.
- last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
+ # stream ordering, so it will process pushes from this point onwards.
+ last_stream_ordering = self.store.get_room_max_stream_ordering()
await self.store.add_pusher(
user_id=user_id,
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index a509e599c2..804da994ea 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
ctx_name = "replication-conn-%s" % self.conn_id
- self._logging_context = BackgroundProcessLoggingContext(ctx_name)
- self._logging_context.request = ctx_name
+ self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py
index 2e56dfaf31..e5c03cc609 100644
--- a/synapse/storage/databases/main/event_push_actions.py
+++ b/synapse/storage/databases/main/event_push_actions.py
@@ -894,16 +894,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"])
return push_actions
- async def get_latest_push_action_stream_ordering(self):
- def f(txn):
- txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
- return txn.fetchone()
-
- result = await self.db_pool.runInteraction(
- "get_latest_push_action_stream_ordering", f
- )
- return result[0] or 0
-
def _remove_old_push_actions_before_txn(
self, txn, room_id, user_id, stream_ordering
):
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index d0452e1490..0b24b89a2e 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
- with LoggingContext(request="send_rejected"):
+ with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
- with LoggingContext(request="send_rejected"):
+ with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -198,7 +198,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
# the auth code requires that a signature exists, but doesn't check that
# signature... go figure.
join_event.signatures[other_server] = {"x": "y"}
- with LoggingContext(request="send_join"):
+ with LoggingContext("send_join"):
d = run_in_background(
self.handler.on_send_join_request, other_server, join_event
)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 1d99a45436..464e569ac8 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -15,7 +15,7 @@
import json
from urllib.parse import parse_qs, urlparse
-from mock import Mock, patch
+from mock import ANY, Mock, patch
import pymacaroons
@@ -23,7 +23,7 @@ from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
from synapse.types import UserID
-from tests.test_utils import FakeResponse
+from tests.test_utils import FakeResponse, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
# These are a few constants that are used as config parameters in the tests.
@@ -82,16 +82,6 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-def simple_async_mock(return_value=None, raises=None):
- # AsyncMock is not available in python3.5, this mimics part of its behaviour
- async def cb(*args, **kwargs):
- if raises:
- raise raises
- return return_value
-
- return Mock(side_effect=cb)
-
-
async def get_json(url):
# Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN:
@@ -160,6 +150,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(args[2], error_description)
# Reset the render_error mock
self.render_error.reset_mock()
+ return args
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
@@ -374,26 +365,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
"id_token": "id_token",
"access_token": "access_token",
}
+ username = "bar"
userinfo = {
"sub": "foo",
- "preferred_username": "bar",
+ "username": username,
}
- user_id = "@foo:domain.org"
+ expected_user_id = "@%s:%s" % (username, self.hs.hostname)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
- self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(
- spec=[
- "args",
- "getCookie",
- "addCookie",
- "requestHeaders",
- "getClientIP",
- "get_user_agent",
- ]
- )
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
code = "code"
state = "state"
@@ -401,64 +383,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
- request.getCookie.return_value = self.handler._generate_oidc_session_token(
+ session = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
-
- request.args = {}
- request.args[b"code"] = [code.encode("utf-8")]
- request.args[b"state"] = [state.encode("utf-8")]
-
- request.getClientIP.return_value = ip_address
- request.get_user_agent.return_value = user_agent
+ request = self._build_callback_request(
+ code, state, session, user_agent=user_agent, ip_address=ip_address
+ )
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ expected_user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(
- userinfo, token, user_agent, ip_address
- )
self.handler._fetch_userinfo.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
- self.handler._map_userinfo_to_user = simple_async_mock(
- raises=MappingException()
- )
- self.get_success(self.handler.handle_oidc_callback(request))
- self.assertRenderedError("mapping_error")
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ with patch.object(
+ self.handler,
+ "_remote_id_from_userinfo",
+ new=Mock(side_effect=MappingException()),
+ ):
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mapping_error")
# Handle ID token errors
self.handler._parse_id_token = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
- self.handler._auth_handler.complete_sso_login.reset_mock()
+ auth_handler.complete_sso_login.reset_mock()
self.handler._exchange_code.reset_mock()
self.handler._parse_id_token.reset_mock()
- self.handler._map_userinfo_to_user.reset_mock()
self.handler._fetch_userinfo.reset_mock()
# With userinfo fetching
self.handler._scopes = [] # do not ask the "openid" scope
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ expected_user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(
- userinfo, token, user_agent, ip_address
- )
self.handler._fetch_userinfo.assert_called_once_with(token)
self.render_error.assert_not_called()
@@ -609,72 +581,55 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
userinfo = {
"sub": "foo",
+ "username": "foo",
"phone": "1234567",
}
- user_id = "@foo:domain.org"
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
- self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(
- spec=[
- "args",
- "getCookie",
- "addCookie",
- "requestHeaders",
- "getClientIP",
- "get_user_agent",
- ]
- )
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
state = "state"
client_redirect_url = "http://client/redirect"
- request.getCookie.return_value = self.handler._generate_oidc_session_token(
+ session = self.handler._generate_oidc_session_token(
state=state,
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
-
- request.args = {}
- request.args[b"code"] = [b"code"]
- request.args[b"state"] = [state.encode("utf-8")]
-
- request.getClientIP.return_value = "10.0.0.1"
- request.get_user_agent.return_value = "Browser"
+ request = self._build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {"phone": "1234567"},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@foo:test", request, client_redirect_url, {"phone": "1234567"},
)
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
userinfo = {
"sub": "test_user",
"username": "test_user",
}
- # The token doesn't matter with the default user mapping provider.
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", ANY, ANY, {}
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Some providers return an integer ID.
userinfo = {
"sub": 1234,
"username": "test_user_2",
}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user_2:test", ANY, ANY, {}
)
- self.assertEqual(mxid, "@test_user_2:test")
+ auth_handler.complete_sso_login.reset_mock()
# Test if the mxid is already taken
store = self.hs.get_datastore()
@@ -683,14 +638,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
- )
- self.assertEqual(
- str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error",
+ "Mapping provider does not support de-duplicating Matrix IDs",
)
@override_config({"oidc_config": {"allow_existing_users": True}})
@@ -702,26 +654,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user.to_string(), password_hash=None)
)
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
# Map a user via SSO.
userinfo = {
"sub": "test",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, {},
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, {},
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
@@ -732,13 +684,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, {},
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
@@ -755,14 +705,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2",
"username": "TEST_USER_2",
}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_not_called()
+ args = self.assertRenderedError("mapping_error")
self.assertTrue(
- str(e.value).startswith(
+ args[2].startswith(
"Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
)
)
@@ -773,28 +720,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None)
)
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@TEST_USER_2:test", ANY, ANY, {},
)
- self.assertEqual(mxid, "@TEST_USER_2:test")
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
- userinfo = {
- "sub": "test2",
- "username": "föö",
- }
- token = {}
-
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
- )
- self.assertEqual(str(e.value), "localpart is invalid: föö")
+ self._make_callback_with_userinfo({"sub": "test2", "username": "föö"})
+ self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config(
{
@@ -807,6 +741,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_map_userinfo_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
@@ -815,14 +752,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
- )
+ self._make_callback_with_userinfo(userinfo)
+
# test_user is already taken, so test_user1 gets registered instead.
- self.assertEqual(mxid, "@test_user1:test")
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user1:test", ANY, ANY, {},
+ )
+ auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular OIDC username.
self.get_success(
@@ -838,12 +774,70 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "tester",
}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self._make_callback_with_userinfo(userinfo)
+ auth_handler.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error", "Unable to generate a Matrix ID from the SSO response"
+ )
+
+ def _make_callback_with_userinfo(
+ self, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+ ) -> None:
+ self.handler._exchange_code = simple_async_mock(return_value={})
+ self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ state = "state"
+ session = self.handler._generate_oidc_session_token(
+ state=state,
+ nonce="nonce",
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=None,
)
- self.assertEqual(
- str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ request = self._build_callback_request("code", state, session)
+
+ self.get_success(self.handler.handle_oidc_callback(request))
+
+ def _build_callback_request(
+ self,
+ code: str,
+ state: str,
+ session: str,
+ user_agent: str = "Browser",
+ ip_address: str = "10.0.0.1",
+ ):
+ """Builds a fake SynapseRequest to mock the browser callback
+
+ Returns a Mock object which looks like the SynapseRequest we get from a browser
+ after SSO (before we return to the client)
+
+ Args:
+ code: the authorization code which would have been returned by the OIDC
+ provider
+ state: the "state" param which would have been passed around in the
+ query param. Should be the same as was embedded in the session in
+ _build_oidc_session.
+ session: the "session" which would have been passed around in the cookie.
+ user_agent: the user-agent to present
+ ip_address: the IP address to pretend the request came from
+ """
+ request = Mock(
+ spec=[
+ "args",
+ "getCookie",
+ "addCookie",
+ "requestHeaders",
+ "getClientIP",
+ "get_user_agent",
+ ]
)
+
+ request.getCookie.return_value = session
+ request.args = {}
+ request.args[b"code"] = [code.encode("utf-8")]
+ request.args[b"state"] = [state.encode("utf-8")]
+ request.getClientIP.return_value = ip_address
+ request.get_user_agent.return_value = user_agent
+ return request
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ceaf0902d2..8d50265145 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -432,6 +432,29 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
+ **providers_config(CustomAuthProvider),
+ "password_config": {"enabled": False, "localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_password_disabled_localdb_enabled(self):
+ """Check the localdb_enabled == enabled == False
+
+ Regression test for https://github.com/matrix-org/synapse/issues/8914: check
+ that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
+ cause an exception.
+ """
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index d21e5588ca..69927cf6be 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
+from mock import Mock
+
import attr
from synapse.api.errors import RedirectException
-from synapse.handlers.sso import MappingException
+from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
# Check if we have the dependencies to run the tests.
@@ -44,6 +48,8 @@ BASE_URL = "https://synapse/"
@attr.s
class FakeAuthnResponse:
ava = attr.ib(type=dict)
+ assertions = attr.ib(type=list, factory=list)
+ in_response_to = attr.ib(type=Optional[str], default=None)
class TestMappingProvider:
@@ -111,15 +117,22 @@ class SamlHandlerTestCase(HomeserverTestCase):
def test_map_saml_response_to_user(self):
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
- # The redirect_url doesn't matter with the default user mapping provider.
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri"
)
- self.assertEqual(mxid, "@test_user:test")
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
def test_map_saml_response_to_existing_user(self):
@@ -129,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase):
store.register_user(user_id="@test_user:test", password_hash=None)
)
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
# Map a user via SSO.
saml_response = FakeAuthnResponse(
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
)
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, ""
)
- self.assertEqual(mxid, "@test_user:test")
# Subsequent calls should map to the same mxid.
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ auth_handler.complete_sso_login.reset_mock()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "")
+ )
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, ""
)
- self.assertEqual(mxid, "@test_user:test")
def test_map_saml_response_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # mock out the error renderer too
+ sso_handler = self.hs.get_sso_handler()
+ sso_handler.render_error = Mock(return_value=None)
+
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
- redirect_url = ""
- e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
+ )
+ sso_handler.render_error.assert_called_once_with(
+ request, "mapping_error", "localpart is invalid: föö"
)
- self.assertEqual(str(e.value), "localpart is invalid: föö")
+ auth_handler.complete_sso_login.assert_not_called()
def test_map_saml_response_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
+
+ # stub out the auth handler and error renderer
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ sso_handler = self.hs.get_sso_handler()
+ sso_handler.render_error = Mock(return_value=None)
+
+ # register a user to occupy the first-choice MXID
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
)
+
+ # send the fake SAML response
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
)
+
# test_user is already taken, so test_user1 gets registered instead.
- self.assertEqual(mxid, "@test_user1:test")
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user1:test", request, ""
+ )
+ auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular SAML username.
self.get_success(
@@ -188,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase):
# Now attempt to map to a username, this will fail since all potential usernames are taken.
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
- e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
)
- self.assertEqual(
- str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ sso_handler.render_error.assert_called_once_with(
+ request,
+ "mapping_error",
+ "Unable to generate a Matrix ID from the SSO response",
)
+ auth_handler.complete_sso_login.assert_not_called()
@override_config(
{
@@ -208,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase):
}
)
def test_map_saml_response_redirect(self):
+ """Test a mapping provider that raises a RedirectException"""
+
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- redirect_url = ""
+ request = _mock_request()
e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
+ self.handler._handle_authn_response(request, saml_response, ""),
RedirectException,
)
self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
+
+
+def _mock_request():
+ """Returns a mock which will stand in as a SynapseRequest"""
+ return Mock(spec=["getClientIP", "get_user_agent"])
diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py
index f6e7e5fdaa..48a74e2eee 100644
--- a/tests/logging/test_terse_json.py
+++ b/tests/logging/test_terse_json.py
@@ -117,11 +117,10 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"""
handler = logging.StreamHandler(self.output)
handler.setFormatter(JsonFormatter())
- handler.addFilter(LoggingContextFilter(request=""))
+ handler.addFilter(LoggingContextFilter())
logger = self.get_logger(handler)
- with LoggingContext() as context_one:
- context_one.request = "test"
+ with LoggingContext(request="test"):
logger.info("Hello there, %s!", "wally")
log = self.get_log_line()
@@ -132,9 +131,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase):
"level",
"namespace",
"request",
- "scope",
]
self.assertCountEqual(log.keys(), expected_log_keys)
self.assertEqual(log["log"], "Hello there, wally!")
self.assertEqual(log["request"], "test")
- self.assertIsNone(log["scope"])
diff --git a/tests/test_federation.py b/tests/test_federation.py
index fa45f8b3b7..fc9aab32d0 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
}
)
- with LoggingContext(request="lying_event"):
+ with LoggingContext():
failure = self.get_failure(
self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 6873d45eb6..43898d8142 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -22,6 +22,8 @@ import warnings
from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar
+from mock import Mock
+
import attr
from twisted.python.failure import Failure
@@ -87,6 +89,16 @@ def setup_awaitable_errors() -> Callable[[], None]:
return cleanup
+def simple_async_mock(return_value=None, raises=None) -> Mock:
+ # AsyncMock is not available in python3.5, this mimics part of its behaviour
+ async def cb(*args, **kwargs):
+ if raises:
+ raise raises
+ return return_value
+
+ return Mock(side_effect=cb)
+
+
@attr.s
class FakeResponse:
"""A fake twisted.web.IResponse object
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index fdfb840b62..52ae5c5713 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -48,7 +48,7 @@ def setup_logging():
handler = ToTwistedHandler()
formatter = logging.Formatter(log_format)
handler.setFormatter(formatter)
- handler.addFilter(LoggingContextFilter(request=""))
+ handler.addFilter(LoggingContextFilter())
root_logger.addHandler(handler)
log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR")
|