diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index af24c4984d..57bfbd7734 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -22,10 +22,10 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
import synapse.storage
-from synapse.api.constants import EduTypes
+from synapse.api.constants import EduTypes, EventTypes
from synapse.appservice import (
ApplicationService,
- TransactionOneTimeKeyCounts,
+ TransactionOneTimeKeysCount,
TransactionUnusedFallbackKeys,
)
from synapse.handlers.appservice import ApplicationServicesHandler
@@ -36,7 +36,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.test_utils import make_awaitable, simple_async_mock
+from tests.test_utils import event_injection, make_awaitable, simple_async_mock
from tests.unittest import override_config
from tests.utils import MockClock
@@ -76,9 +76,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
- self.mock_store.get_all_new_events_stream.side_effect = [
- make_awaitable((0, [], {})),
- make_awaitable((1, [event], {event.event_id: 0})),
+ self.mock_store.get_all_new_event_ids_stream.side_effect = [
+ make_awaitable((0, {})),
+ make_awaitable((1, {event.event_id: 0})),
+ ]
+ self.mock_store.get_events_as_list.side_effect = [
+ make_awaitable([]),
+ make_awaitable([event]),
]
self.handler.notify_interested_services(RoomStreamToken(None, 1))
@@ -95,10 +99,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_events_stream.side_effect = [
- make_awaitable((0, [event], {event.event_id: 0})),
+ self.mock_store.get_all_new_event_ids_stream.side_effect = [
+ make_awaitable((0, {event.event_id: 0})),
]
-
+ self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@@ -112,7 +116,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_events_stream.side_effect = [
+ self.mock_store.get_all_new_event_ids_stream.side_effect = [
make_awaitable((0, [event], {event.event_id: 0})),
]
@@ -386,15 +390,16 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
receipts.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ self.hs = hs
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events
self.send_mock = simple_async_mock()
- hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock
+ hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment]
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
- self.hs.get_datastores().main.get_app_services = Mock(
+ self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
return_value=self._services
)
@@ -412,6 +417,157 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"exclusive_as_user", "password", self.exclusive_as_user_device_id
)
+ def _notify_interested_services(self):
+ # This is normally set in `notify_interested_services` but we need to call the
+ # internal async version so the reactor gets pushed to completion.
+ self.hs.get_application_service_handler().current_max += 1
+ self.get_success(
+ self.hs.get_application_service_handler()._notify_interested_services(
+ RoomStreamToken(
+ None, self.hs.get_application_service_handler().current_max
+ )
+ )
+ )
+
+ @parameterized.expand(
+ [
+ ("@local_as_user:test", True),
+ # Defining remote users in an application service user namespace regex is a
+ # footgun since the appservice might assume that it'll receive all events
+ # sent by that remote user, but it will only receive events in rooms that
+ # are shared with a local user. So we just remove this footgun possibility
+ # entirely and we won't notify the application service based on remote
+ # users.
+ ("@remote_as_user:remote", False),
+ ]
+ )
+ def test_match_interesting_room_members(
+ self, interesting_user: str, should_notify: bool
+ ):
+ """
+ Test to make sure that a interesting user (local or remote) in the room is
+ notified as expected when someone else in the room sends a message.
+ """
+ # Register an application service that's interested in the `interesting_user`
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": interesting_user,
+ "exclusive": False,
+ },
+ ],
+ },
+ )
+
+ # Create a room
+ alice = self.register_user("alice", "pass")
+ alice_access_token = self.login("alice", "pass")
+ room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token)
+
+ # Join the interesting user to the room
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, room_id, interesting_user, "join"
+ )
+ )
+ # Kick the appservice into checking this membership event to get the event out
+ # of the way
+ self._notify_interested_services()
+ # We don't care about the interesting user join event (this test is making sure
+ # the next thing works)
+ self.send_mock.reset_mock()
+
+ # Send a message from an uninteresting user
+ self.helper.send_event(
+ room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message from uninteresting user",
+ },
+ tok=alice_access_token,
+ )
+ # Kick the appservice into checking this new event
+ self._notify_interested_services()
+
+ if should_notify:
+ self.send_mock.assert_called_once()
+ (
+ service,
+ events,
+ _ephemeral,
+ _to_device_messages,
+ _otks,
+ _fbks,
+ _device_list_summary,
+ ) = self.send_mock.call_args[0]
+
+ # Even though the message came from an uninteresting user, it should still
+ # notify us because the interesting user is joined to the room where the
+ # message was sent.
+ self.assertEqual(service, interested_appservice)
+ self.assertEqual(events[0]["type"], "m.room.message")
+ self.assertEqual(events[0]["sender"], alice)
+ else:
+ self.send_mock.assert_not_called()
+
+ def test_application_services_receive_events_sent_by_interesting_local_user(self):
+ """
+ Test to make sure that a messages sent from a local user can be interesting and
+ picked up by the appservice.
+ """
+ # Register an application service that's interested in all local users
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": ".*",
+ "exclusive": False,
+ },
+ ],
+ },
+ )
+
+ # Create a room
+ alice = self.register_user("alice", "pass")
+ alice_access_token = self.login("alice", "pass")
+ room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token)
+
+ # We don't care about interesting events before this (this test is making sure
+ # the next thing works)
+ self.send_mock.reset_mock()
+
+ # Send a message from the interesting local user
+ self.helper.send_event(
+ room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message from interesting local user",
+ },
+ tok=alice_access_token,
+ )
+ # Kick the appservice into checking this new event
+ self._notify_interested_services()
+
+ self.send_mock.assert_called_once()
+ (
+ service,
+ events,
+ _ephemeral,
+ _to_device_messages,
+ _otks,
+ _fbks,
+ _device_list_summary,
+ ) = self.send_mock.call_args[0]
+
+ # Events sent from an interesting local user should also be picked up as
+ # interesting to the appservice.
+ self.assertEqual(service, interested_appservice)
+ self.assertEqual(events[0]["type"], "m.room.message")
+ self.assertEqual(events[0]["sender"], alice)
+
def test_sending_read_receipt_batches_to_application_services(self):
"""Tests that a large batch of read receipts are sent correctly to
interested application services.
@@ -609,7 +765,12 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
fake_device_ids = [f"device_{num}" for num in range(number_of_messages - 1)]
messages = {
self.exclusive_as_user: {
- device_id: to_device_message_content for device_id in fake_device_ids
+ device_id: {
+ "type": "test_to_device_message",
+ "sender": "@some:sender",
+ "content": to_device_message_content,
+ }
+ for device_id in fake_device_ids
}
}
@@ -967,7 +1128,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
# Capture what was sent as an AS transaction.
self.send_mock.assert_called()
last_args, _last_kwargs = self.send_mock.call_args
- otks: Optional[TransactionOneTimeKeyCounts] = last_args[self.ARG_OTK_COUNTS]
+ otks: Optional[TransactionOneTimeKeysCount] = last_args[self.ARG_OTK_COUNTS]
unused_fallbacks: Optional[TransactionUnusedFallbackKeys] = last_args[
self.ARG_FALLBACK_KEYS
]
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 7106799d44..036dbbc45b 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from unittest.mock import Mock
import pymacaroons
@@ -19,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import AuthError, ResourceLimitError
from synapse.rest import admin
+from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
@@ -29,6 +31,7 @@ from tests.test_utils import make_awaitable
class AuthTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
+ login.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -46,6 +49,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1 = self.register_user("a_user", "pass")
+ def token_login(self, token: str) -> Optional[str]:
+ body = {
+ "type": "m.login.token",
+ "token": token,
+ }
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/login",
+ body,
+ )
+
+ if channel.code == 200:
+ return channel.json_body["user_id"]
+
+ return None
+
def test_macaroon_caveats(self) -> None:
token = self.macaroon_generator.generate_guest_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -73,49 +93,62 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.satisfy_general(verify_guest)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
- def test_short_term_login_token_gives_user_id(self) -> None:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ def test_login_token_gives_user_id(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ duration_ms=(5 * 1000),
+ )
)
- res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+
+ res = self.get_success(self.auth_handler.consume_login_token(token))
self.assertEqual(self.user1, res.user_id)
- self.assertEqual("", res.auth_provider_id)
+ self.assertEqual(None, res.auth_provider_id)
- # when we advance the clock, the token should be rejected
- self.reactor.advance(6)
- self.get_failure(
- self.auth_handler.validate_short_term_login_token(token),
- AuthError,
+ def test_login_token_reuse_fails(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ duration_ms=(5 * 1000),
+ )
)
- def test_short_term_login_token_gives_auth_provider(self) -> None:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, auth_provider_id="my_idp"
- )
- res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
- self.assertEqual(self.user1, res.user_id)
- self.assertEqual("my_idp", res.auth_provider_id)
+ self.get_success(self.auth_handler.consume_login_token(token))
- def test_short_term_login_token_cannot_replace_user_id(self) -> None:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ self.get_failure(
+ self.auth_handler.consume_login_token(token),
+ AuthError,
)
- macaroon = pymacaroons.Macaroon.deserialize(token)
- res = self.get_success(
- self.auth_handler.validate_short_term_login_token(macaroon.serialize())
+ def test_login_token_expires(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ duration_ms=(5 * 1000),
+ )
)
- self.assertEqual(self.user1, res.user_id)
-
- # add another "user_id" caveat, which might allow us to override the
- # user_id.
- macaroon.add_first_party_caveat("user_id = b_user")
+ # when we advance the clock, the token should be rejected
+ self.reactor.advance(6)
self.get_failure(
- self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
+ self.auth_handler.consume_login_token(token),
AuthError,
)
+ def test_login_token_gives_auth_provider(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ auth_provider_id="my_idp",
+ auth_provider_session_id="11-22-33-44",
+ duration_ms=(5 * 1000),
+ )
+ )
+ res = self.get_success(self.auth_handler.consume_login_token(token))
+ self.assertEqual(self.user1, res.user_id)
+ self.assertEqual("my_idp", res.auth_provider_id)
+ self.assertEqual("11-22-33-44", res.auth_provider_session_id)
+
def test_mau_limits_disabled(self) -> None:
self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
@@ -125,12 +158,12 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
- self.get_success(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- )
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNotNone(self.token_login(token))
+
def test_mau_limits_exceeded_large(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock(
@@ -147,12 +180,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users)
)
- self.get_failure(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- ),
- ResourceLimitError,
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNone(self.token_login(token))
def test_mau_limits_parity(self) -> None:
# Ensure we're not at the unix epoch.
@@ -171,12 +202,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
),
ResourceLimitError,
)
- self.get_failure(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- ),
- ResourceLimitError,
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNone(self.token_login(token))
# If in monthly active cohort
self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
@@ -187,11 +216,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1, device_id=None, valid_until_ms=None
)
)
- self.get_success(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- )
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNotNone(self.token_login(token))
def test_mau_limits_not_exceeded(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
@@ -209,14 +237,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.small_number_of_users)
)
- self.get_success(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- )
- )
-
- def _get_macaroon(self) -> pymacaroons.Macaroon:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
- return pymacaroons.Macaroon.deserialize(token)
+ self.assertIsNotNone(self.token_login(token))
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index b8b465d35b..ce7525e29c 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -19,7 +19,7 @@ from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError, SynapseError
-from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
+from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.server import HomeServer
from synapse.util import Clock
@@ -32,7 +32,9 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.store = hs.get_datastores().main
return hs
@@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display name")
def test_device_is_preserved_if_exists(self) -> None:
@@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res2, "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display name")
def test_device_id_is_made_up_if_unspecified(self) -> None:
@@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display")
def test_get_devices_by_user(self) -> None:
@@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
@@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
)
)
- retrieved_device_id, device_data = self.get_success(
- self.handler.get_dehydrated_device(user_id=user_id)
- )
+ result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
+ assert result is not None
+ retrieved_device_id, device_data = result
self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 745750b1d7..d00c69c229 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -19,7 +19,13 @@ from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ LimitExceededError,
+ NotFoundError,
+ SynapseError,
+)
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_base import event_from_pdu_json
@@ -28,6 +34,7 @@ from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
+from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -322,6 +329,102 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
)
self.get_success(d)
+ def test_backfill_ignores_known_events(self) -> None:
+ """
+ Tests that events that we already know about are ignored when backfilling.
+ """
+ # Set up users
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ other_server = "otherserver"
+ other_user = "@otheruser:" + other_server
+
+ # Create a room to backfill events into
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ # Build an event to backfill
+ event = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {"body": "hello world", "msgtype": "m.text"},
+ "room_id": room_id,
+ "sender": other_user,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ # Ensure the event is not already in the DB
+ self.get_failure(
+ self.store.get_event(event.event_id),
+ NotFoundError,
+ )
+
+ # Backfill the event and check that it has entered the DB.
+
+ # We mock out the FederationClient.backfill method, to pretend that a remote
+ # server has returned our fake event.
+ federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
+ self.hs.get_federation_client().backfill = federation_client_backfill_mock
+
+ # We also mock the persist method with a side effect of itself. This allows us
+ # to track when it has been called while preserving its function.
+ persist_events_and_notify_mock = Mock(
+ side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
+ )
+ self.hs.get_federation_event_handler().persist_events_and_notify = (
+ persist_events_and_notify_mock
+ )
+
+ # Small side-tangent. We populate the event cache with the event, even though
+ # it is not yet in the DB. This is an invalid scenario that can currently occur
+ # due to not properly invalidating the event cache.
+ # See https://github.com/matrix-org/synapse/issues/13476.
+ #
+ # As a result, backfill should not rely on the event cache to check whether
+ # we already have an event in the DB.
+ # TODO: Remove this bit when the event cache is properly invalidated.
+ cache_entry = EventCacheEntry(
+ event=event,
+ redacted_event=None,
+ )
+ self.store._get_event_cache.set_local((event.event_id,), cache_entry)
+
+ # We now call FederationEventHandler.backfill (a separate method) to trigger
+ # a backfill request. It should receive the fake event.
+ self.get_success(
+ self.hs.get_federation_event_handler().backfill(
+ other_user,
+ room_id,
+ limit=10,
+ extremities=[],
+ )
+ )
+
+ # Check that our fake event was persisted.
+ persist_events_and_notify_mock.assert_called_once()
+ persist_events_and_notify_mock.reset_mock()
+
+ # Now we repeat the backfill, having the homeserver receive the fake event
+ # again.
+ self.get_success(
+ self.hs.get_federation_event_handler().backfill(
+ other_user,
+ room_id,
+ limit=10,
+ extremities=[],
+ ),
+ )
+
+ # This time, we expect no event persistence to have occurred, as we already
+ # have this event.
+ persist_events_and_notify_mock.assert_not_called()
+
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 918010cddb..e448cb1901 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -14,7 +14,7 @@
from typing import Optional
from unittest import mock
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, StoreError
from synapse.api.room_versions import RoomVersion
from synapse.event_auth import (
check_state_dependent_auth_rules,
@@ -43,7 +43,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
def make_homeserver(self, reactor, clock):
# mock out the federation transport client
self.mock_federation_transport_client = mock.Mock(
- spec=["get_room_state_ids", "get_room_state", "get_event"]
+ spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
)
return super().setup_test_homeserver(
federation_transport_client=self.mock_federation_transport_client
@@ -459,6 +459,203 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
self.assertIsNotNone(persisted, "pulled event was not persisted at all")
+ def test_backfill_signature_failure_does_not_fetch_same_prev_event_later(
+ self,
+ ) -> None:
+ """
+ Test to make sure we backoff and don't try to fetch a missing prev_event when we
+ already know it has a invalid signature from checking the signatures of all of
+ the events in the backfill response.
+ """
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+
+ # Create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(main_store.get_room_version(room_id))
+
+ # Allow the remote user to send state events
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {"events_default": 0, "state_default": 0},
+ tok=tok,
+ )
+
+ # Add the remote user to the room
+ member_event = self.get_success(
+ event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
+ )
+
+ initial_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
+
+ auth_event_ids = [
+ initial_state_map[("m.room.create", "")],
+ initial_state_map[("m.room.power_levels", "")],
+ member_event.event_id,
+ ]
+
+ # We purposely don't run `add_hashes_and_signatures_from_other_server`
+ # over this because we want the signature check to fail.
+ pulled_event_without_signatures = make_event_from_dict(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [member_event.event_id],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled_event_without_signatures"},
+ },
+ room_version,
+ )
+
+ # Create a regular event that should pass except for the
+ # `pulled_event_without_signatures` in the `prev_event`.
+ pulled_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [
+ member_event.event_id,
+ pulled_event_without_signatures.event_id,
+ ],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled_event"},
+ }
+ ),
+ room_version,
+ )
+
+ # We expect an outbound request to /backfill, so stub that out
+ self.mock_federation_transport_client.backfill.return_value = make_awaitable(
+ {
+ "origin": self.OTHER_SERVER_NAME,
+ "origin_server_ts": 123,
+ "pdus": [
+ # This is one of the important aspects of this test: we include
+ # `pulled_event_without_signatures` so it fails the signature check
+ # when we filter down the backfill response down to events which
+ # have valid signatures in
+ # `_check_sigs_and_hash_for_pulled_events_and_fetch`
+ pulled_event_without_signatures.get_pdu_json(),
+ # Then later when we process this valid signature event, when we
+ # fetch the missing `prev_event`s, we want to make sure that we
+ # backoff and don't try and fetch `pulled_event_without_signatures`
+ # again since we know it just had an invalid signature.
+ pulled_event.get_pdu_json(),
+ ],
+ }
+ )
+
+ # Keep track of the count and make sure we don't make any of these requests
+ event_endpoint_requested_count = 0
+ room_state_ids_endpoint_requested_count = 0
+ room_state_endpoint_requested_count = 0
+
+ async def get_event(
+ destination: str, event_id: str, timeout: Optional[int] = None
+ ) -> None:
+ nonlocal event_endpoint_requested_count
+ event_endpoint_requested_count += 1
+
+ async def get_room_state_ids(
+ destination: str, room_id: str, event_id: str
+ ) -> None:
+ nonlocal room_state_ids_endpoint_requested_count
+ room_state_ids_endpoint_requested_count += 1
+
+ async def get_room_state(
+ room_version: RoomVersion, destination: str, room_id: str, event_id: str
+ ) -> None:
+ nonlocal room_state_endpoint_requested_count
+ room_state_endpoint_requested_count += 1
+
+ # We don't expect an outbound request to `/event`, `/state_ids`, or `/state` in
+ # the happy path but if the logic is sneaking around what we expect, stub that
+ # out so we can detect that failure
+ self.mock_federation_transport_client.get_event.side_effect = get_event
+ self.mock_federation_transport_client.get_room_state_ids.side_effect = (
+ get_room_state_ids
+ )
+ self.mock_federation_transport_client.get_room_state.side_effect = (
+ get_room_state
+ )
+
+ # The function under test: try to backfill and process the pulled event
+ with LoggingContext("test"):
+ self.get_success(
+ self.hs.get_federation_event_handler().backfill(
+ self.OTHER_SERVER_NAME,
+ room_id,
+ limit=1,
+ extremities=["$some_extremity"],
+ )
+ )
+
+ if event_endpoint_requested_count > 0:
+ self.fail(
+ "We don't expect an outbound request to /event in the happy path but if "
+ "the logic is sneaking around what we expect, make sure to fail the test. "
+ "We don't expect it because the signature failure should cause us to backoff "
+ "and not asking about pulled_event_without_signatures="
+ f"{pulled_event_without_signatures.event_id} again"
+ )
+
+ if room_state_ids_endpoint_requested_count > 0:
+ self.fail(
+ "We don't expect an outbound request to /state_ids in the happy path but if "
+ "the logic is sneaking around what we expect, make sure to fail the test. "
+ "We don't expect it because the signature failure should cause us to backoff "
+ "and not asking about pulled_event_without_signatures="
+ f"{pulled_event_without_signatures.event_id} again"
+ )
+
+ if room_state_endpoint_requested_count > 0:
+ self.fail(
+ "We don't expect an outbound request to /state in the happy path but if "
+ "the logic is sneaking around what we expect, make sure to fail the test. "
+ "We don't expect it because the signature failure should cause us to backoff "
+ "and not asking about pulled_event_without_signatures="
+ f"{pulled_event_without_signatures.event_id} again"
+ )
+
+ # Make sure we only recorded a single failure which corresponds to the signature
+ # failure initially in `_check_sigs_and_hash_for_pulled_events_and_fetch` before
+ # we process all of the pulled events.
+ backfill_num_attempts_for_event_without_signatures = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event_without_signatures.event_id},
+ retcol="num_attempts",
+ )
+ )
+ self.assertEqual(backfill_num_attempts_for_event_without_signatures, 1)
+
+ # And make sure we didn't record a failure for the event that has the missing
+ # prev_event because we don't want to cause a cascade of failures. Not being
+ # able to fetch the `prev_events` just means we won't be able to de-outlier the
+ # pulled event. But we can still use an `outlier` in the state/auth chain for
+ # another event. So we shouldn't stop a downstream event from trying to pull it.
+ self.get_failure(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ ),
+ # StoreError: 404: No row found
+ StoreError,
+ )
+
def test_process_pulled_event_with_rejected_missing_state(self) -> None:
"""Ensure that we correctly handle pulled events with missing state containing a
rejected state event
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e6cd3af7b7..5955410524 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import os
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
@@ -22,12 +21,15 @@ import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.sso import MappingException
+from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import UserID
from synapse.util import Clock
-from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon
+from synapse.util.macaroons import get_value_from_macaroon
+from synapse.util.stringutils import random_string
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
from tests.unittest import HomeserverTestCase, override_config
try:
@@ -46,12 +48,6 @@ BASE_URL = "https://synapse/"
CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
SCOPES = ["openid"]
-AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
-TOKEN_ENDPOINT = ISSUER + "token"
-USERINFO_ENDPOINT = ISSUER + "userinfo"
-WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
-JWKS_URI = ISSUER + ".well-known/jwks.json"
-
# config for common cases
DEFAULT_CONFIG = {
"enabled": True,
@@ -66,9 +62,9 @@ DEFAULT_CONFIG = {
EXPLICIT_ENDPOINT_CONFIG = {
**DEFAULT_CONFIG,
"discover": False,
- "authorization_endpoint": AUTHORIZATION_ENDPOINT,
- "token_endpoint": TOKEN_ENDPOINT,
- "jwks_uri": JWKS_URI,
+ "authorization_endpoint": ISSUER + "authorize",
+ "token_endpoint": ISSUER + "token",
+ "jwks_uri": ISSUER + "jwks",
}
@@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-async def get_json(url: str) -> JsonDict:
- # Mock get_json calls to handle jwks & oidc discovery endpoints
- if url == WELL_KNOWN:
- # Minimal discovery document, as defined in OpenID.Discovery
- # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
- return {
- "issuer": ISSUER,
- "authorization_endpoint": AUTHORIZATION_ENDPOINT,
- "token_endpoint": TOKEN_ENDPOINT,
- "jwks_uri": JWKS_URI,
- "userinfo_endpoint": USERINFO_ENDPOINT,
- "response_types_supported": ["code"],
- "subject_types_supported": ["public"],
- "id_token_signing_alg_values_supported": ["RS256"],
- }
- elif url == JWKS_URI:
- return {"keys": []}
-
- return {}
-
-
def _key_file_path() -> str:
"""path to a file containing the private half of a test key"""
@@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
return config
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.http_client = Mock(spec=["get_json"])
- self.http_client.get_json.side_effect = get_json
- self.http_client.user_agent = b"Synapse Test"
+ self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER)
- hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+ hs = self.setup_test_homeserver()
+ self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
+ self.hs_patcher.start()
self.handler = hs.get_oidc_handler()
self.provider = self.handler._providers["oidc"]
@@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3
+ auth_handler = hs.get_auth_handler()
+ # Mock the complete SSO login method.
+ self.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment]
+
return hs
+ def tearDown(self) -> None:
+ self.hs_patcher.stop()
+ return super().tearDown()
+
+ def reset_mocks(self):
+ """Reset all the Mocks."""
+ self.fake_server.reset_mocks()
+ self.render_error.reset_mock()
+ self.complete_sso_login.reset_mock()
+
def metadata_edit(self, values):
"""Modify the result that will be returned by the well-known query"""
- async def patched_get_json(uri):
- res = await get_json(uri)
- if uri == WELL_KNOWN:
- res.update(values)
- return res
+ metadata = self.fake_server.get_metadata()
+ metadata.update(values)
+ return patch.object(self.fake_server, "get_metadata", return_value=metadata)
- return patch.object(self.http_client, "get_json", patched_get_json)
+ def start_authorization(
+ self,
+ userinfo: dict,
+ client_redirect_url: str = "http://client/redirect",
+ scope: str = "openid",
+ with_sid: bool = False,
+ ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]:
+ """Start an authorization request, and get the callback request back."""
+ nonce = random_string(10)
+ state = random_string(10)
+
+ code, grant = self.fake_server.start_authorization(
+ userinfo=userinfo,
+ scope=scope,
+ client_id=self.provider._client_auth.client_id,
+ redirect_uri=self.provider._callback_url,
+ nonce=nonce,
+ with_sid=with_sid,
+ )
+ session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+ return _build_callback_request(code, state, session), grant
def assertRenderedError(self, error, error_description=None):
self.render_error.assert_called_once()
@@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
metadata = self.get_success(self.provider.load_metadata())
- self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.fake_server.get_metadata_handler.assert_called_once()
- self.assertEqual(metadata.issuer, ISSUER)
- self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
- self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
- self.assertEqual(metadata.jwks_uri, JWKS_URI)
- # FIXME: it seems like authlib does not have that defined in its metadata models
- # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+ self.assertEqual(metadata.issuer, self.fake_server.issuer)
+ self.assertEqual(
+ metadata.authorization_endpoint,
+ self.fake_server.authorization_endpoint,
+ )
+ self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint)
+ self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri)
+ # It seems like authlib does not have that defined in its metadata models
+ self.assertEqual(
+ metadata.get("userinfo_endpoint"),
+ self.fake_server.userinfo_endpoint,
+ )
# subsequent calls should be cached
- self.http_client.reset_mock()
+ self.reset_mocks()
self.get_success(self.provider.load_metadata())
- self.http_client.get_json.assert_not_called()
+ self.fake_server.get_metadata_handler.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
- self.http_client.get_json.assert_not_called()
+ self.fake_server.get_metadata_handler.assert_not_called()
- @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
- self.http_client.get_json.assert_called_once_with(JWKS_URI)
- self.assertEqual(jwks, {"keys": []})
+ self.fake_server.get_jwks_handler.assert_called_once()
+ self.assertEqual(jwks, self.fake_server.get_jwks())
# subsequent calls should be cached…
- self.http_client.reset_mock()
+ self.reset_mocks()
self.get_success(self.provider.load_jwks())
- self.http_client.get_json.assert_not_called()
+ self.fake_server.get_jwks_handler.assert_not_called()
# …unless forced
- self.http_client.reset_mock()
+ self.reset_mocks()
self.get_success(self.provider.load_jwks(force=True))
- self.http_client.get_json.assert_called_once_with(JWKS_URI)
+ self.fake_server.get_jwks_handler.assert_called_once()
- # Throw if the JWKS uri is missing
- original = self.provider.load_metadata
-
- async def patched_load_metadata():
- m = (await original()).copy()
- m.update({"jwks_uri": None})
- return m
-
- with patch.object(self.provider, "load_metadata", patched_load_metadata):
+ with self.metadata_edit({"jwks_uri": None}):
+ # If we don't do this, the load_metadata call will throw because of the
+ # missing jwks_uri
+ self.provider._user_profile_method = "userinfo_endpoint"
+ self.get_success(self.provider.load_metadata(force=True))
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
@override_config({"oidc_config": DEFAULT_CONFIG})
@@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.provider.handle_redirect_request(req, b"http://client/redirect")
)
)
- auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+ auth_endpoint = urlparse(self.fake_server.authorization_endpoint)
self.assertEqual(url.scheme, auth_endpoint.scheme)
self.assertEqual(url.netloc, auth_endpoint.netloc)
@@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase):
with self.assertRaises(AttributeError):
_ = mapping_provider.get_extra_attributes
- token = {
- "type": "bearer",
- "id_token": "id_token",
- "access_token": "access_token",
- }
username = "bar"
userinfo = {
"sub": "foo",
"username": username,
}
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
- code = "code"
- state = "state"
- nonce = "nonce"
client_redirect_url = "http://client/redirect"
- ip_address = "10.0.0.1"
- session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
- request = _build_callback_request(code, state, session, ip_address=ip_address)
-
+ request, _ = self.start_authorization(
+ userinfo, client_redirect_url=client_redirect_url
+ )
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
expected_user_id,
- "oidc",
+ self.provider.idp_id,
request,
client_redirect_url,
None,
new_user=True,
auth_provider_session_id=None,
)
- self.provider._exchange_code.assert_called_once_with(code)
- self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.provider._fetch_userinfo.assert_not_called()
+ self.fake_server.post_token_handler.assert_called_once()
+ self.fake_server.get_userinfo_handler.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
+ request, _ = self.start_authorization(userinfo)
with patch.object(
self.provider,
"_remote_id_from_userinfo",
@@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error")
# Handle ID token errors
- self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
- self.get_success(self.handler.handle_oidc_callback(request))
+ request, _ = self.start_authorization(userinfo)
+ with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}):
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
- auth_handler.complete_sso_login.reset_mock()
- self.provider._exchange_code.reset_mock()
- self.provider._parse_id_token.reset_mock()
- self.provider._fetch_userinfo.reset_mock()
+ self.reset_mocks()
# With userinfo fetching
self.provider._user_profile_method = "userinfo_endpoint"
- token = {
- "type": "bearer",
- "access_token": "access_token",
- }
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ # Without the "openid" scope, the FakeProvider does not generate an id_token
+ request, _ = self.start_authorization(userinfo, scope="")
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
expected_user_id,
- "oidc",
+ self.provider.idp_id,
request,
- client_redirect_url,
+ ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- self.provider._exchange_code.assert_called_once_with(code)
- self.provider._parse_id_token.assert_not_called()
- self.provider._fetch_userinfo.assert_called_once_with(token)
+ self.fake_server.post_token_handler.assert_called_once()
+ self.fake_server.get_userinfo_handler.assert_called_once()
self.render_error.assert_not_called()
+ self.reset_mocks()
+
# With an ID token, userinfo fetching and sid in the ID token
self.provider._user_profile_method = "userinfo_endpoint"
- token = {
- "type": "bearer",
- "access_token": "access_token",
- "id_token": "id_token",
- }
- id_token = {
- "sid": "abcdefgh",
- }
- self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
- auth_handler.complete_sso_login.reset_mock()
- self.provider._fetch_userinfo.reset_mock()
+ request, grant = self.start_authorization(userinfo, with_sid=True)
+ self.assertIsNotNone(grant.sid)
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
expected_user_id,
- "oidc",
+ self.provider.idp_id,
request,
- client_redirect_url,
+ ANY,
None,
new_user=False,
- auth_provider_session_id=id_token["sid"],
+ auth_provider_session_id=grant.sid,
)
- self.provider._exchange_code.assert_called_once_with(code)
- self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.provider._fetch_userinfo.assert_called_once_with(token)
+ self.fake_server.post_token_handler.assert_called_once()
+ self.fake_server.get_userinfo_handler.assert_called_once()
self.render_error.assert_not_called()
# Handle userinfo fetching error
- self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
- self.get_success(self.handler.handle_oidc_callback(request))
+ request, _ = self.start_authorization(userinfo)
+ with self.fake_server.buggy_endpoint(userinfo=True):
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
- # Handle code exchange failure
- from synapse.handlers.oidc import OidcError
-
- self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
- raises=OidcError("invalid_request")
- )
- self.get_success(self.handler.handle_oidc_callback(request))
- self.assertRenderedError("invalid_request")
+ request, _ = self.start_authorization(userinfo)
+ with self.fake_server.buggy_endpoint(token=True):
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("server_error")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_session(self) -> None:
@@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_exchange_code(self) -> None:
"""Code exchange behaves correctly and handles various error scenarios."""
- token = {"type": "bearer"}
- token_json = json.dumps(token).encode("utf-8")
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
)
code = "code"
ret = self.get_success(self.provider._exchange_code(code))
- kwargs = self.http_client.request.call_args[1]
+ kwargs = self.fake_server.request.call_args[1]
self.assertEqual(ret, token)
self.assertEqual(kwargs["method"], "POST")
- self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+ self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
args = parse_qs(kwargs["data"].decode("utf-8"))
self.assertEqual(args["grant_type"], ["authorization_code"])
@@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
# Test error handling
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=400,
- phrase=b"Bad Request",
- body=b'{"error": "foo", "error_description": "bar"}',
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=400, payload={"error": "foo", "error_description": "bar"}
)
from synapse.handlers.oidc import OidcError
@@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(exc.value.error_description, "bar")
# Internal server error with no JSON body
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=500,
- phrase=b"Internal Server Error",
- body=b"Not JSON",
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse(
+ code=500, body=b"Not JSON"
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=500,
- phrase=b"Internal Server Error",
- body=b'{"error": "internal_server_error"}',
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=500, payload={"error": "internal_server_error"}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=400,
- phrase=b"Bad request",
- body=b"{}",
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=400, payload={}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=200,
- phrase=b"OK",
- body=b'{"error": "some_error"}',
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=200, payload={"error": "some_error"}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "some_error")
@@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt
- token = {"type": "bearer"}
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
- )
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
)
code = "code"
@@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(ret, token)
# the request should have hit the token endpoint
- kwargs = self.http_client.request.call_args[1]
+ kwargs = self.fake_server.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
- self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+ self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
# the client secret provided to the should be a jwt which can be checked with
# the public key
@@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_exchange_code_no_auth(self) -> None:
"""Test that code exchange works with no client secret."""
- token = {"type": "bearer"}
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
- )
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
)
code = "code"
ret = self.get_success(self.provider._exchange_code(code))
@@ -762,9 +730,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(ret, token)
# the request should have hit the token endpoint
- kwargs = self.http_client.request.call_args[1]
+ kwargs = self.fake_server.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
- self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+ self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
# check the POSTed data
args = parse_qs(kwargs["data"].decode("utf-8"))
@@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""
Login while using a mapping provider that implements get_extra_attributes.
"""
- token = {
- "type": "bearer",
- "id_token": "id_token",
- "access_token": "access_token",
- }
userinfo = {
"sub": "foo",
"username": "foo",
"phone": "1234567",
}
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
- state = "state"
- client_redirect_url = "http://client/redirect"
- session = self._generate_oidc_session_token(
- state=state,
- nonce="nonce",
- client_redirect_url=client_redirect_url,
- )
- request = _build_callback_request("code", state, session)
-
+ request, _ = self.start_authorization(userinfo)
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@foo:test",
- "oidc",
+ self.provider.idp_id,
request,
- client_redirect_url,
+ ANY,
{"phone": "1234567"},
new_user=True,
auth_provider_session_id=None,
@@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
userinfo: dict = {
"sub": "test_user",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
"@test_user:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Some providers return an integer ID.
userinfo = {
"sub": 1234,
"username": "test_user_2",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
"@test_user_2:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Test if the mxid is already taken
store = self.hs.get_datastores().main
@@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error",
"Mapping provider does not support de-duplicating Matrix IDs",
@@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user.to_string(), password_hash=None)
)
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
# Map a user via SSO.
userinfo = {
"sub": "test",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
user.to_string(),
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Subsequent calls should map to the same mxid.
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
user.to_string(),
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
@@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
user.to_string(),
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
@@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2",
"username": "TEST_USER_2",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
args = self.assertRenderedError("mapping_error")
self.assertTrue(
args[2].startswith(
@@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None)
)
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
"@TEST_USER_2:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
@@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
- self.get_success(
- _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
- )
+ userinfo = {"sub": "test2", "username": "föö"}
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config(
@@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_map_userinfo_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
store = self.hs.get_datastores().main
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
@@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
# test_user is already taken, so test_user1 gets registered instead.
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@test_user1:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Register all of the potential mxids for a particular OIDC username.
self.get_success(
@@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "tester",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
@@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error", "localpart is invalid: ")
@override_config(
@@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": None,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error", "localpart is invalid: ")
@override_config(
@@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the OIDC userinfo response."""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
# userinfo lacking "test": "foobar" attribute should fail.
userinfo = {
"sub": "tester",
"username": "tester",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": "foobar" attribute should succeed.
userinfo = {
@@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": "foobar",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
# check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@tester:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
@@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_attribute_requirements_contains(self) -> None:
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed.
userinfo = {
"sub": "tester",
"username": "tester",
"test": ["foobar", "foo", "bar"],
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
# check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@tester:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
@@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
Test that auth fails if attributes exist but don't match,
or are non-string values.
"""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": "not_foobar" attribute should fail
userinfo: dict = {
"sub": "tester",
"username": "tester",
"test": "not_foobar",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": ["foo", "bar"] attribute should fail
userinfo = {
@@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": ["foo", "bar"],
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": False attribute should fail
# this is largely just to ensure we don't crash here
@@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": False,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": None attribute should fail
# a value of None breaks the OIDC spec, but it's important to not crash here
@@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": None,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": 1 attribute should fail
# this is largely just to ensure we don't crash here
@@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": 1,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": 3.14 attribute should fail
# this is largely just to ensure we don't crash here
@@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": 3.14,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
def _generate_oidc_session_token(
self,
@@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return self.handler._macaroon_generator.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
- idp_id="oidc",
+ idp_id=self.provider.idp_id,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
@@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
-async def _make_callback_with_userinfo(
- hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
-) -> None:
- """Mock up an OIDC callback with the given userinfo dict
-
- We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
- and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
-
- Args:
- hs: the HomeServer impl to send the callback to.
- userinfo: the OIDC userinfo dict
- client_redirect_url: the URL to redirect to on success.
- """
-
- handler = hs.get_oidc_handler()
- provider = handler._providers["oidc"]
- provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
- provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
-
- state = "state"
- session = handler._macaroon_generator.generate_oidc_session_token(
- state=state,
- session_data=OidcSessionData(
- idp_id="oidc",
- nonce="nonce",
- client_redirect_url=client_redirect_url,
- ui_auth_session_id="",
- ),
- )
- request = _build_callback_request("code", state, session)
-
- await handler.handle_oidc_callback(request)
-
-
def _build_callback_request(
code: str,
state: str,
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index c96dc6caf2..584e7b8971 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -15,6 +15,7 @@
from typing import Optional
from unittest.mock import Mock, call
+from parameterized import parameterized
from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState
@@ -37,6 +38,7 @@ from synapse.rest.client import room
from synapse.types import UserID, get_domain_from_id
from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
@@ -505,7 +507,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(state, new_state)
-class PresenceHandlerTestCase(unittest.HomeserverTestCase):
+class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
def prepare(self, reactor, clock, hs):
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
@@ -716,20 +718,47 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
# our status message should be the same as it was before
self.assertEqual(state.status_msg, status_msg)
- def test_set_presence_from_syncing_keeps_busy(self):
- """Test that presence set by syncing doesn't affect busy status"""
- # while this isn't the default
- self.presence_handler._busy_presence_enabled = True
+ @parameterized.expand([(False,), (True,)])
+ @unittest.override_config(
+ {
+ "experimental_features": {
+ "msc3026_enabled": True,
+ },
+ }
+ )
+ def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool):
+ """Test that presence set by syncing doesn't affect busy status
+ Args:
+ test_with_workers: If True, check the presence state of the user by calling
+ /sync against a worker, rather than the main process.
+ """
user_id = "@test:server"
status_msg = "I'm busy!"
+ # By default, we call /sync against the main process.
+ worker_to_sync_against = self.hs
+ if test_with_workers:
+ # Create a worker and use it to handle /sync traffic instead.
+ # This is used to test that presence changes get replicated from workers
+ # to the main process correctly.
+ worker_to_sync_against = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "presence_writer"}
+ )
+
+ # Set presence to BUSY
self._set_presencestate_with_status_msg(user_id, PresenceState.BUSY, status_msg)
+ # Perform a sync with a presence state other than busy. This should NOT change
+ # our presence status; we only change from busy if we explicitly set it via
+ # /presence/*.
self.get_success(
- self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE)
+ worker_to_sync_against.get_presence_handler().user_syncing(
+ user_id, True, PresenceState.ONLINE
+ )
)
+ # Check against the main process that the user's presence did not change.
state = self.get_success(
self.presence_handler.get_state(UserID.from_string(user_id))
)
@@ -963,7 +992,8 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def default_config(self):
config = super().default_config()
- config["send_federation"] = True
+ # Enable federation sending on the main process.
+ config["federation_sender_instances"] = None
return config
def prepare(self, reactor, clock, hs):
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index f88c725a42..675aa023ac 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -14,6 +14,8 @@
from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
+from parameterized import parameterized
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.types
@@ -327,6 +329,53 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(res)
+ @unittest.override_config(
+ {"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]}
+ )
+ def test_avatar_constraint_on_local_server_with_port(self):
+ """Test that avatar metadata is correctly fetched when the media is on a local
+ server and the server has an explicit port.
+
+ (This was previously a bug)
+ """
+ local_server_name = self.hs.config.server.server_name
+ media_id = "local"
+ local_mxc = f"mxc://{local_server_name}/{media_id}"
+
+ # mock up the existence of the avatar file
+ self._setup_local_files({media_id: {"mimetype": "image/png"}})
+
+ # and now check that check_avatar_size_and_mime_type is happy
+ self.assertTrue(
+ self.get_success(self.handler.check_avatar_size_and_mime_type(local_mxc))
+ )
+
+ @parameterized.expand([("remote",), ("remote:1234",)])
+ @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
+ def test_check_avatar_on_remote_server(self, remote_server_name: str) -> None:
+ """Test that avatar metadata is correctly fetched from a remote server"""
+ media_id = "remote"
+ remote_mxc = f"mxc://{remote_server_name}/{media_id}"
+
+ # if the media is remote, check_avatar_size_and_mime_type just checks the
+ # media cache, so we don't need to instantiate a real remote server. It is
+ # sufficient to poke an entry into the db.
+ self.get_success(
+ self.hs.get_datastores().main.store_cached_remote_media(
+ media_id=media_id,
+ media_type="image/png",
+ media_length=50,
+ origin=remote_server_name,
+ time_now_ms=self.clock.time_msec(),
+ upload_name=None,
+ filesystem_id="xyz",
+ )
+ )
+
+ self.assertTrue(
+ self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc))
+ )
+
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
"""Stores metadata about files in the database.
diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py
new file mode 100644
index 0000000000..137deab138
--- /dev/null
+++ b/tests/handlers/test_sso.py
@@ -0,0 +1,145 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from http import HTTPStatus
+from typing import BinaryIO, Callable, Dict, List, Optional, Tuple
+from unittest.mock import Mock
+
+from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.http_headers import Headers
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.client import RawHeaders
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.test_utils import SMALL_PNG, FakeResponse
+
+
+class TestSSOHandler(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.http_client = Mock(spec=["get_file"])
+ self.http_client.get_file.side_effect = mock_get_file
+ self.http_client.user_agent = b"Synapse Test"
+ hs = self.setup_test_homeserver(
+ proxied_blacklisted_http_client=self.http_client
+ )
+ return hs
+
+ async def test_set_avatar(self) -> None:
+ """Tests successfully setting the avatar of a newly created user"""
+ handler = self.hs.get_sso_handler()
+
+ # Create a new user to set avatar for
+ reg_handler = self.hs.get_registration_handler()
+ user_id = self.get_success(reg_handler.register_user(approved=True))
+
+ self.assertTrue(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ # Ensure avatar is set on this newly created user,
+ # so no need to compare for the exact image
+ profile_handler = self.hs.get_profile_handler()
+ profile = self.get_success(profile_handler.get_profile(user_id))
+ self.assertIsNot(profile["avatar_url"], None)
+
+ @unittest.override_config({"max_avatar_size": 1})
+ async def test_set_avatar_too_big_image(self) -> None:
+ """Tests that saving an avatar fails when it is too big"""
+ handler = self.hs.get_sso_handler()
+
+ # any random user works since image check is supposed to fail
+ user_id = "@sso-user:test"
+
+ self.assertFalse(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ @unittest.override_config({"allowed_avatar_mimetypes": ["image/jpeg"]})
+ async def test_set_avatar_incorrect_mime_type(self) -> None:
+ """Tests that saving an avatar fails when its mime type is not allowed"""
+ handler = self.hs.get_sso_handler()
+
+ # any random user works since image check is supposed to fail
+ user_id = "@sso-user:test"
+
+ self.assertFalse(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ async def test_skip_saving_avatar_when_not_changed(self) -> None:
+ """Tests whether saving of avatar correctly skips if the avatar hasn't
+ changed"""
+ handler = self.hs.get_sso_handler()
+
+ # Create a new user to set avatar for
+ reg_handler = self.hs.get_registration_handler()
+ user_id = self.get_success(reg_handler.register_user(approved=True))
+
+ # set avatar for the first time, should be a success
+ self.assertTrue(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ # get avatar picture for comparison after another attempt
+ profile_handler = self.hs.get_profile_handler()
+ profile = self.get_success(profile_handler.get_profile(user_id))
+ url_to_match = profile["avatar_url"]
+
+ # set same avatar for the second time, should be a success
+ self.assertTrue(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ # compare avatar picture's url from previous step
+ profile = self.get_success(profile_handler.get_profile(user_id))
+ self.assertEqual(profile["avatar_url"], url_to_match)
+
+
+async def mock_get_file(
+ url: str,
+ output_stream: BinaryIO,
+ max_size: Optional[int] = None,
+ headers: Optional[RawHeaders] = None,
+ is_allowed_content_type: Optional[Callable[[str], bool]] = None,
+) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
+
+ fake_response = FakeResponse(code=404)
+ if url == "http://my.server/me.png":
+ fake_response = FakeResponse(
+ code=200,
+ headers=Headers(
+ {"Content-Type": ["image/png"], "Content-Length": [str(len(SMALL_PNG))]}
+ ),
+ body=SMALL_PNG,
+ )
+
+ if max_size is not None and max_size < len(SMALL_PNG):
+ raise SynapseError(
+ HTTPStatus.BAD_GATEWAY,
+ "Requested file is too large > %r bytes" % (max_size,),
+ Codes.TOO_LARGE,
+ )
+
+ if is_allowed_content_type and not is_allowed_content_type("image/png"):
+ raise SynapseError(
+ HTTPStatus.BAD_GATEWAY,
+ (
+ "Requested file's content type not allowed for this operation: %s"
+ % "image/png"
+ ),
+ )
+
+ output_stream.write(fake_response.body)
+
+ return len(SMALL_PNG), {b"Content-Type": [b"image/png"]}, "", 200
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 9c821b3042..efbb5a8dbb 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -200,7 +200,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
],
)
- @override_config({"send_federation": True})
+ # Enable federation sending on the main process.
+ @override_config({"federation_sender_instances": None})
def test_started_typing_remote_send(self) -> None:
self.room_members = [U_APPLE, U_ONION]
@@ -305,7 +306,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(events[0], [])
self.assertEqual(events[1], 0)
- @override_config({"send_federation": True})
+ # Enable federation sending on the main process.
+ @override_config({"federation_sender_instances": None})
def test_stopped_typing(self) -> None:
self.room_members = [U_APPLE, U_BANANA, U_ONION]
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 9e39cd97e5..75fc5a17a4 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -56,7 +56,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
- config["update_user_directory"] = True
+ # Re-enables updating the user directory, as that function is needed below.
+ config["update_user_directory_from_worker"] = None
self.appservice = ApplicationService(
token="i_am_an_app_service",
@@ -1045,7 +1046,9 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
- config["update_user_directory"] = True
+ # Re-enables updating the user directory, as that function is needed below. It
+ # will be force disabled later
+ config["update_user_directory_from_worker"] = None
hs = self.setup_test_homeserver(config=config)
self.config = hs.config
|