diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index b17af2725b..144e49d0fd 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -22,7 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
import synapse.storage
-from synapse.api.constants import EduTypes
+from synapse.api.constants import EduTypes, EventTypes
from synapse.appservice import (
ApplicationService,
TransactionOneTimeKeyCounts,
@@ -36,7 +36,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.test_utils import make_awaitable, simple_async_mock
+from tests.test_utils import event_injection, make_awaitable, simple_async_mock
from tests.unittest import override_config
from tests.utils import MockClock
@@ -76,9 +76,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
- self.mock_store.get_all_new_events_stream.side_effect = [
- make_awaitable((0, [], {})),
- make_awaitable((1, [event], {event.event_id: 0})),
+ self.mock_store.get_all_new_event_ids_stream.side_effect = [
+ make_awaitable((0, {})),
+ make_awaitable((1, {event.event_id: 0})),
+ ]
+ self.mock_store.get_events_as_list.side_effect = [
+ make_awaitable([]),
+ make_awaitable([event]),
]
self.handler.notify_interested_services(RoomStreamToken(None, 1))
@@ -95,10 +99,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_events_stream.side_effect = [
- make_awaitable((0, [event], {event.event_id: 0})),
+ self.mock_store.get_all_new_event_ids_stream.side_effect = [
+ make_awaitable((0, {event.event_id: 0})),
]
-
+ self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
self.handler.notify_interested_services(RoomStreamToken(None, 0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@@ -112,7 +116,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_events_stream.side_effect = [
+ self.mock_store.get_all_new_event_ids_stream.side_effect = [
make_awaitable((0, [event], {event.event_id: 0})),
]
@@ -386,15 +390,16 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
receipts.register_servlets,
]
- def prepare(self, reactor, clock, hs):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ self.hs = hs
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events
self.send_mock = simple_async_mock()
- hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock
+ hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment]
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
- self.hs.get_datastores().main.get_app_services = Mock(
+ self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
return_value=self._services
)
@@ -412,6 +417,157 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"exclusive_as_user", "password", self.exclusive_as_user_device_id
)
+ def _notify_interested_services(self):
+ # This is normally set in `notify_interested_services` but we need to call the
+ # internal async version so the reactor gets pushed to completion.
+ self.hs.get_application_service_handler().current_max += 1
+ self.get_success(
+ self.hs.get_application_service_handler()._notify_interested_services(
+ RoomStreamToken(
+ None, self.hs.get_application_service_handler().current_max
+ )
+ )
+ )
+
+ @parameterized.expand(
+ [
+ ("@local_as_user:test", True),
+ # Defining remote users in an application service user namespace regex is a
+ # footgun since the appservice might assume that it'll receive all events
+ # sent by that remote user, but it will only receive events in rooms that
+ # are shared with a local user. So we just remove this footgun possibility
+ # entirely and we won't notify the application service based on remote
+ # users.
+ ("@remote_as_user:remote", False),
+ ]
+ )
+ def test_match_interesting_room_members(
+ self, interesting_user: str, should_notify: bool
+ ):
+ """
+ Test to make sure that a interesting user (local or remote) in the room is
+ notified as expected when someone else in the room sends a message.
+ """
+ # Register an application service that's interested in the `interesting_user`
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": interesting_user,
+ "exclusive": False,
+ },
+ ],
+ },
+ )
+
+ # Create a room
+ alice = self.register_user("alice", "pass")
+ alice_access_token = self.login("alice", "pass")
+ room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token)
+
+ # Join the interesting user to the room
+ self.get_success(
+ event_injection.inject_member_event(
+ self.hs, room_id, interesting_user, "join"
+ )
+ )
+ # Kick the appservice into checking this membership event to get the event out
+ # of the way
+ self._notify_interested_services()
+ # We don't care about the interesting user join event (this test is making sure
+ # the next thing works)
+ self.send_mock.reset_mock()
+
+ # Send a message from an uninteresting user
+ self.helper.send_event(
+ room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message from uninteresting user",
+ },
+ tok=alice_access_token,
+ )
+ # Kick the appservice into checking this new event
+ self._notify_interested_services()
+
+ if should_notify:
+ self.send_mock.assert_called_once()
+ (
+ service,
+ events,
+ _ephemeral,
+ _to_device_messages,
+ _otks,
+ _fbks,
+ _device_list_summary,
+ ) = self.send_mock.call_args[0]
+
+ # Even though the message came from an uninteresting user, it should still
+ # notify us because the interesting user is joined to the room where the
+ # message was sent.
+ self.assertEqual(service, interested_appservice)
+ self.assertEqual(events[0]["type"], "m.room.message")
+ self.assertEqual(events[0]["sender"], alice)
+ else:
+ self.send_mock.assert_not_called()
+
+ def test_application_services_receive_events_sent_by_interesting_local_user(self):
+ """
+ Test to make sure that a messages sent from a local user can be interesting and
+ picked up by the appservice.
+ """
+ # Register an application service that's interested in all local users
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": ".*",
+ "exclusive": False,
+ },
+ ],
+ },
+ )
+
+ # Create a room
+ alice = self.register_user("alice", "pass")
+ alice_access_token = self.login("alice", "pass")
+ room_id = self.helper.create_room_as(room_creator=alice, tok=alice_access_token)
+
+ # We don't care about interesting events before this (this test is making sure
+ # the next thing works)
+ self.send_mock.reset_mock()
+
+ # Send a message from the interesting local user
+ self.helper.send_event(
+ room_id,
+ type=EventTypes.Message,
+ content={
+ "msgtype": "m.text",
+ "body": "message from interesting local user",
+ },
+ tok=alice_access_token,
+ )
+ # Kick the appservice into checking this new event
+ self._notify_interested_services()
+
+ self.send_mock.assert_called_once()
+ (
+ service,
+ events,
+ _ephemeral,
+ _to_device_messages,
+ _otks,
+ _fbks,
+ _device_list_summary,
+ ) = self.send_mock.call_args[0]
+
+ # Events sent from an interesting local user should also be picked up as
+ # interesting to the appservice.
+ self.assertEqual(service, interested_appservice)
+ self.assertEqual(events[0]["type"], "m.room.message")
+ self.assertEqual(events[0]["sender"], alice)
+
def test_sending_read_receipt_batches_to_application_services(self):
"""Tests that a large batch of read receipts are sent correctly to
interested application services.
@@ -447,6 +603,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
receipt_type="m.read",
user_id=self.local_user,
event_ids=[f"$eventid_{i}"],
+ thread_id=None,
data={},
)
)
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 7106799d44..036dbbc45b 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from unittest.mock import Mock
import pymacaroons
@@ -19,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import AuthError, ResourceLimitError
from synapse.rest import admin
+from synapse.rest.client import login
from synapse.server import HomeServer
from synapse.util import Clock
@@ -29,6 +31,7 @@ from tests.test_utils import make_awaitable
class AuthTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
+ login.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -46,6 +49,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1 = self.register_user("a_user", "pass")
+ def token_login(self, token: str) -> Optional[str]:
+ body = {
+ "type": "m.login.token",
+ "token": token,
+ }
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/login",
+ body,
+ )
+
+ if channel.code == 200:
+ return channel.json_body["user_id"]
+
+ return None
+
def test_macaroon_caveats(self) -> None:
token = self.macaroon_generator.generate_guest_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -73,49 +93,62 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.satisfy_general(verify_guest)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
- def test_short_term_login_token_gives_user_id(self) -> None:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ def test_login_token_gives_user_id(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ duration_ms=(5 * 1000),
+ )
)
- res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+
+ res = self.get_success(self.auth_handler.consume_login_token(token))
self.assertEqual(self.user1, res.user_id)
- self.assertEqual("", res.auth_provider_id)
+ self.assertEqual(None, res.auth_provider_id)
- # when we advance the clock, the token should be rejected
- self.reactor.advance(6)
- self.get_failure(
- self.auth_handler.validate_short_term_login_token(token),
- AuthError,
+ def test_login_token_reuse_fails(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ duration_ms=(5 * 1000),
+ )
)
- def test_short_term_login_token_gives_auth_provider(self) -> None:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, auth_provider_id="my_idp"
- )
- res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
- self.assertEqual(self.user1, res.user_id)
- self.assertEqual("my_idp", res.auth_provider_id)
+ self.get_success(self.auth_handler.consume_login_token(token))
- def test_short_term_login_token_cannot_replace_user_id(self) -> None:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ self.get_failure(
+ self.auth_handler.consume_login_token(token),
+ AuthError,
)
- macaroon = pymacaroons.Macaroon.deserialize(token)
- res = self.get_success(
- self.auth_handler.validate_short_term_login_token(macaroon.serialize())
+ def test_login_token_expires(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ duration_ms=(5 * 1000),
+ )
)
- self.assertEqual(self.user1, res.user_id)
-
- # add another "user_id" caveat, which might allow us to override the
- # user_id.
- macaroon.add_first_party_caveat("user_id = b_user")
+ # when we advance the clock, the token should be rejected
+ self.reactor.advance(6)
self.get_failure(
- self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
+ self.auth_handler.consume_login_token(token),
AuthError,
)
+ def test_login_token_gives_auth_provider(self) -> None:
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(
+ self.user1,
+ auth_provider_id="my_idp",
+ auth_provider_session_id="11-22-33-44",
+ duration_ms=(5 * 1000),
+ )
+ )
+ res = self.get_success(self.auth_handler.consume_login_token(token))
+ self.assertEqual(self.user1, res.user_id)
+ self.assertEqual("my_idp", res.auth_provider_id)
+ self.assertEqual("11-22-33-44", res.auth_provider_session_id)
+
def test_mau_limits_disabled(self) -> None:
self.auth_blocking._limit_usage_by_mau = False
# Ensure does not throw exception
@@ -125,12 +158,12 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
- self.get_success(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- )
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNotNone(self.token_login(token))
+
def test_mau_limits_exceeded_large(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
self.hs.get_datastores().main.get_monthly_active_count = Mock(
@@ -147,12 +180,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.large_number_of_users)
)
- self.get_failure(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- ),
- ResourceLimitError,
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNone(self.token_login(token))
def test_mau_limits_parity(self) -> None:
# Ensure we're not at the unix epoch.
@@ -171,12 +202,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
),
ResourceLimitError,
)
- self.get_failure(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- ),
- ResourceLimitError,
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNone(self.token_login(token))
# If in monthly active cohort
self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
@@ -187,11 +216,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.user1, device_id=None, valid_until_ms=None
)
)
- self.get_success(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- )
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
+ self.assertIsNotNone(self.token_login(token))
def test_mau_limits_not_exceeded(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
@@ -209,14 +237,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.hs.get_datastores().main.get_monthly_active_count = Mock(
return_value=make_awaitable(self.small_number_of_users)
)
- self.get_success(
- self.auth_handler.validate_short_term_login_token(
- self._get_macaroon().serialize()
- )
- )
-
- def _get_macaroon(self) -> pymacaroons.Macaroon:
- token = self.macaroon_generator.generate_short_term_login_token(
- self.user1, "", duration_in_ms=5000
+ token = self.get_success(
+ self.auth_handler.create_login_token_for_user_id(self.user1)
)
- return pymacaroons.Macaroon.deserialize(token)
+ self.assertIsNotNone(self.token_login(token))
diff --git a/tests/handlers/test_deactivate_account.py b/tests/handlers/test_deactivate_account.py
index 7586e472b5..bce65fab7d 100644
--- a/tests/handlers/test_deactivate_account.py
+++ b/tests/handlers/test_deactivate_account.py
@@ -11,8 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from http import HTTPStatus
-from typing import Any, Dict
from twisted.test.proto_helpers import MemoryReactor
@@ -21,6 +19,7 @@ from synapse.push.rulekinds import PRIORITY_CLASS_MAP
from synapse.rest import admin
from synapse.rest.client import account, login
from synapse.server import HomeServer
+from synapse.synapse_rust.push import PushRule
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -58,7 +57,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
access_token=self.token,
)
- self.assertEqual(req.code, HTTPStatus.OK, req)
+ self.assertEqual(req.code, 200, req)
def test_global_account_data_deleted_upon_deactivation(self) -> None:
"""
@@ -131,12 +130,12 @@ class DeactivateAccountTestCase(HomeserverTestCase):
),
)
- def _is_custom_rule(self, push_rule: Dict[str, Any]) -> bool:
+ def _is_custom_rule(self, push_rule: PushRule) -> bool:
"""
Default rules start with a dot: such as .m.rule and .im.vector.
This function returns true iff a rule is custom (not default).
"""
- return "/." not in push_rule["rule_id"]
+ return "/." not in push_rule.rule_id
def test_push_rules_deleted_upon_account_deactivation(self) -> None:
"""
@@ -158,32 +157,30 @@ class DeactivateAccountTestCase(HomeserverTestCase):
)
# Test the rule exists
- push_rules = self.get_success(self._store.get_push_rules_for_user(self.user))
+ filtered_push_rules = self.get_success(
+ self._store.get_push_rules_for_user(self.user)
+ )
# Filter out default rules; we don't care
- push_rules = list(filter(self._is_custom_rule, push_rules))
+ push_rules = [
+ r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r)
+ ]
# Check our rule made it
- self.assertEqual(
- push_rules,
- [
- {
- "user_name": "@user:test",
- "rule_id": "personal.override.rule1",
- "priority_class": 5,
- "priority": 0,
- "conditions": [],
- "actions": [],
- "default": False,
- }
- ],
- push_rules,
- )
+ self.assertEqual(len(push_rules), 1)
+ self.assertEqual(push_rules[0].rule_id, "personal.override.rule1")
+ self.assertEqual(push_rules[0].priority_class, 5)
+ self.assertEqual(push_rules[0].conditions, [])
+ self.assertEqual(push_rules[0].actions, [])
# Request the deactivation of our account
self._deactivate_my_account()
- push_rules = self.get_success(self._store.get_push_rules_for_user(self.user))
+ filtered_push_rules = self.get_success(
+ self._store.get_push_rules_for_user(self.user)
+ )
# Filter out default rules; we don't care
- push_rules = list(filter(self._is_custom_rule, push_rules))
+ push_rules = [
+ r for r, _ in filtered_push_rules.rules() if self._is_custom_rule(r)
+ ]
# Check our rule no longer exists
self.assertEqual(push_rules, [], push_rules)
@@ -322,3 +319,18 @@ class DeactivateAccountTestCase(HomeserverTestCase):
)
),
)
+
+ def test_deactivate_account_needs_auth(self) -> None:
+ """
+ Tests that making a request to /deactivate with an empty body
+ succeeds in starting the user-interactive auth flow.
+ """
+ req = self.make_request(
+ "POST",
+ "account/deactivate",
+ {},
+ access_token=self.token,
+ )
+
+ self.assertEqual(req.code, 401, req)
+ self.assertEqual(req.json_body["flows"], [{"stages": ["m.login.password"]}])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index b8b465d35b..ce7525e29c 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -19,7 +19,7 @@ from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError, SynapseError
-from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
+from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.server import HomeServer
from synapse.util import Clock
@@ -32,7 +32,9 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.store = hs.get_datastores().main
return hs
@@ -61,6 +63,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res, "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display name")
def test_device_is_preserved_if_exists(self) -> None:
@@ -83,6 +86,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(res2, "fco")
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display name")
def test_device_id_is_made_up_if_unspecified(self) -> None:
@@ -95,6 +99,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
+ assert dev is not None
self.assertEqual(dev["display_name"], "display")
def test_get_devices_by_user(self) -> None:
@@ -264,7 +269,9 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server", federation_http_client=None)
- self.handler = hs.get_device_handler()
+ handler = hs.get_device_handler()
+ assert isinstance(handler, DeviceHandler)
+ self.handler = handler
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
@@ -284,9 +291,9 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
)
)
- retrieved_device_id, device_data = self.get_success(
- self.handler.get_dehydrated_device(user_id=user_id)
- )
+ result = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
+ assert result is not None
+ retrieved_device_id, device_data = result
self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 1e6ad4b663..95698bc275 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -891,6 +891,12 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
new_callable=mock.MagicMock,
return_value=make_awaitable(["some_room_id"]),
)
+ mock_get_users = mock.patch.object(
+ self.store,
+ "get_users_server_still_shares_room_with",
+ new_callable=mock.MagicMock,
+ return_value=make_awaitable({remote_user_id}),
+ )
mock_request = mock.patch.object(
self.hs.get_federation_client(),
"query_user_devices",
@@ -898,7 +904,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(response_body),
)
- with mock_get_rooms, mock_request as mocked_federation_request:
+ with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request:
# Make the first query and sanity check it succeeds.
response_1 = self.get_success(
e2e_handler.query_devices(
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 745750b1d7..d00c69c229 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -19,7 +19,13 @@ from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
+from synapse.api.errors import (
+ AuthError,
+ Codes,
+ LimitExceededError,
+ NotFoundError,
+ SynapseError,
+)
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_base import event_from_pdu_json
@@ -28,6 +34,7 @@ from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
+from synapse.storage.databases.main.events_worker import EventCacheEntry
from synapse.util import Clock
from synapse.util.stringutils import random_string
@@ -322,6 +329,102 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
)
self.get_success(d)
+ def test_backfill_ignores_known_events(self) -> None:
+ """
+ Tests that events that we already know about are ignored when backfilling.
+ """
+ # Set up users
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ other_server = "otherserver"
+ other_user = "@otheruser:" + other_server
+
+ # Create a room to backfill events into
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+
+ # Build an event to backfill
+ event = event_from_pdu_json(
+ {
+ "type": EventTypes.Message,
+ "content": {"body": "hello world", "msgtype": "m.text"},
+ "room_id": room_id,
+ "sender": other_user,
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ # Ensure the event is not already in the DB
+ self.get_failure(
+ self.store.get_event(event.event_id),
+ NotFoundError,
+ )
+
+ # Backfill the event and check that it has entered the DB.
+
+ # We mock out the FederationClient.backfill method, to pretend that a remote
+ # server has returned our fake event.
+ federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
+ self.hs.get_federation_client().backfill = federation_client_backfill_mock
+
+ # We also mock the persist method with a side effect of itself. This allows us
+ # to track when it has been called while preserving its function.
+ persist_events_and_notify_mock = Mock(
+ side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
+ )
+ self.hs.get_federation_event_handler().persist_events_and_notify = (
+ persist_events_and_notify_mock
+ )
+
+ # Small side-tangent. We populate the event cache with the event, even though
+ # it is not yet in the DB. This is an invalid scenario that can currently occur
+ # due to not properly invalidating the event cache.
+ # See https://github.com/matrix-org/synapse/issues/13476.
+ #
+ # As a result, backfill should not rely on the event cache to check whether
+ # we already have an event in the DB.
+ # TODO: Remove this bit when the event cache is properly invalidated.
+ cache_entry = EventCacheEntry(
+ event=event,
+ redacted_event=None,
+ )
+ self.store._get_event_cache.set_local((event.event_id,), cache_entry)
+
+ # We now call FederationEventHandler.backfill (a separate method) to trigger
+ # a backfill request. It should receive the fake event.
+ self.get_success(
+ self.hs.get_federation_event_handler().backfill(
+ other_user,
+ room_id,
+ limit=10,
+ extremities=[],
+ )
+ )
+
+ # Check that our fake event was persisted.
+ persist_events_and_notify_mock.assert_called_once()
+ persist_events_and_notify_mock.reset_mock()
+
+ # Now we repeat the backfill, having the homeserver receive the fake event
+ # again.
+ self.get_success(
+ self.hs.get_federation_event_handler().backfill(
+ other_user,
+ room_id,
+ limit=10,
+ extremities=[],
+ ),
+ )
+
+ # This time, we expect no event persistence to have occurred, as we already
+ # have this event.
+ persist_events_and_notify_mock.assert_not_called()
+
@unittest.override_config(
{"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
)
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 51c8dd6498..e448cb1901 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -11,14 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
from unittest import mock
+from synapse.api.errors import AuthError, StoreError
+from synapse.api.room_versions import RoomVersion
+from synapse.event_auth import (
+ check_state_dependent_auth_rules,
+ check_state_independent_auth_rules,
+)
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.federation.transport.client import StateRequestResponse
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
+from synapse.types import JsonDict
from tests import unittest
from tests.test_utils import event_injection, make_awaitable
@@ -34,7 +43,7 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
def make_homeserver(self, reactor, clock):
# mock out the federation transport client
self.mock_federation_transport_client = mock.Mock(
- spec=["get_room_state_ids", "get_room_state", "get_event"]
+ spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
)
return super().setup_test_homeserver(
federation_transport_client=self.mock_federation_transport_client
@@ -227,3 +236,812 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
if prev_exists_as_outlier:
self.mock_federation_transport_client.get_event.assert_not_called()
+
+ def test_process_pulled_event_records_failed_backfill_attempts(
+ self,
+ ) -> None:
+ """
+ Test to make sure that failed backfill attempts for an event are
+ recorded in the `event_failed_pull_attempts` table.
+
+ In this test, we pretend we are processing a "pulled" event via
+ backfill. The pulled event has a fake `prev_event` which our server has
+ obviously never seen before so it attempts to request the state at that
+ `prev_event` which expectedly fails because it's a fake event. Because
+ the server can't fetch the state at the missing `prev_event`, the
+ "pulled" event fails the history check and is fails to process.
+
+ We check that we correctly record the number of failed pull attempts
+ of the pulled event and as a sanity check, that the "pulled" event isn't
+ persisted.
+ """
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+
+ # Create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(main_store.get_room_version(room_id))
+
+ # We expect an outbound request to /state_ids, so stub that out
+ self.mock_federation_transport_client.get_room_state_ids.return_value = make_awaitable(
+ {
+ # Mimic the other server not knowing about the state at all.
+ # We want to cause Synapse to throw an error (`Unable to get
+ # missing prev_event $fake_prev_event`) and fail to backfill
+ # the pulled event.
+ "pdu_ids": [],
+ "auth_chain_ids": [],
+ }
+ )
+ # We also expect an outbound request to /state
+ self.mock_federation_transport_client.get_room_state.return_value = make_awaitable(
+ StateRequestResponse(
+ # Mimic the other server not knowing about the state at all.
+ # We want to cause Synapse to throw an error (`Unable to get
+ # missing prev_event $fake_prev_event`) and fail to backfill
+ # the pulled event.
+ auth_events=[],
+ state=[],
+ )
+ )
+
+ pulled_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [
+ # The fake prev event will make the pulled event fail
+ # the history check (`Unable to get missing prev_event
+ # $fake_prev_event`)
+ "$fake_prev_event"
+ ],
+ "auth_events": [],
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled"},
+ }
+ ),
+ room_version,
+ )
+
+ # The function under test: try to process the pulled event
+ with LoggingContext("test"):
+ self.get_success(
+ self.hs.get_federation_event_handler()._process_pulled_event(
+ self.OTHER_SERVER_NAME, pulled_event, backfilled=True
+ )
+ )
+
+ # Make sure our failed pull attempt was recorded
+ backfill_num_attempts = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ )
+ )
+ self.assertEqual(backfill_num_attempts, 1)
+
+ # The function under test: try to process the pulled event again
+ with LoggingContext("test"):
+ self.get_success(
+ self.hs.get_federation_event_handler()._process_pulled_event(
+ self.OTHER_SERVER_NAME, pulled_event, backfilled=True
+ )
+ )
+
+ # Make sure our second failed pull attempt was recorded (`num_attempts` was incremented)
+ backfill_num_attempts = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ )
+ )
+ self.assertEqual(backfill_num_attempts, 2)
+
+ # And as a sanity check, make sure the event was not persisted through all of this.
+ persisted = self.get_success(
+ main_store.get_event(pulled_event.event_id, allow_none=True)
+ )
+ self.assertIsNone(
+ persisted,
+ "pulled event that fails the history check should not be persisted at all",
+ )
+
+ def test_process_pulled_event_clears_backfill_attempts_after_being_successfully_persisted(
+ self,
+ ) -> None:
+ """
+ Test to make sure that failed pull attempts
+ (`event_failed_pull_attempts` table) for an event are cleared after the
+ event is successfully persisted.
+
+ In this test, we pretend we are processing a "pulled" event via
+ backfill. The pulled event succesfully processes and the backward
+ extremeties are updated along with clearing out any failed pull attempts
+ for those old extremities.
+
+ We check that we correctly cleared failed pull attempts of the
+ pulled event.
+ """
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+
+ # Create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(main_store.get_room_version(room_id))
+
+ # allow the remote user to send state events
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {"events_default": 0, "state_default": 0},
+ tok=tok,
+ )
+
+ # add the remote user to the room
+ member_event = self.get_success(
+ event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
+ )
+
+ initial_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
+
+ auth_event_ids = [
+ initial_state_map[("m.room.create", "")],
+ initial_state_map[("m.room.power_levels", "")],
+ member_event.event_id,
+ ]
+
+ pulled_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [member_event.event_id],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled"},
+ }
+ ),
+ room_version,
+ )
+
+ # Fake the "pulled" event failing to backfill once so we can test
+ # if it's cleared out later on.
+ self.get_success(
+ main_store.record_event_failed_pull_attempt(
+ pulled_event.room_id, pulled_event.event_id, "fake cause"
+ )
+ )
+ # Make sure we have a failed pull attempt recorded for the pulled event
+ backfill_num_attempts = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ )
+ )
+ self.assertEqual(backfill_num_attempts, 1)
+
+ # The function under test: try to process the pulled event
+ with LoggingContext("test"):
+ self.get_success(
+ self.hs.get_federation_event_handler()._process_pulled_event(
+ self.OTHER_SERVER_NAME, pulled_event, backfilled=True
+ )
+ )
+
+ # Make sure the failed pull attempts for the pulled event are cleared
+ backfill_num_attempts = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ allow_none=True,
+ )
+ )
+ self.assertIsNone(backfill_num_attempts)
+
+ # And as a sanity check, make sure the "pulled" event was persisted.
+ persisted = self.get_success(
+ main_store.get_event(pulled_event.event_id, allow_none=True)
+ )
+ self.assertIsNotNone(persisted, "pulled event was not persisted at all")
+
+ def test_backfill_signature_failure_does_not_fetch_same_prev_event_later(
+ self,
+ ) -> None:
+ """
+ Test to make sure we backoff and don't try to fetch a missing prev_event when we
+ already know it has a invalid signature from checking the signatures of all of
+ the events in the backfill response.
+ """
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+
+ # Create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(main_store.get_room_version(room_id))
+
+ # Allow the remote user to send state events
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {"events_default": 0, "state_default": 0},
+ tok=tok,
+ )
+
+ # Add the remote user to the room
+ member_event = self.get_success(
+ event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
+ )
+
+ initial_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
+
+ auth_event_ids = [
+ initial_state_map[("m.room.create", "")],
+ initial_state_map[("m.room.power_levels", "")],
+ member_event.event_id,
+ ]
+
+ # We purposely don't run `add_hashes_and_signatures_from_other_server`
+ # over this because we want the signature check to fail.
+ pulled_event_without_signatures = make_event_from_dict(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [member_event.event_id],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled_event_without_signatures"},
+ },
+ room_version,
+ )
+
+ # Create a regular event that should pass except for the
+ # `pulled_event_without_signatures` in the `prev_event`.
+ pulled_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "test_regular_type",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [
+ member_event.event_id,
+ pulled_event_without_signatures.event_id,
+ ],
+ "auth_events": auth_event_ids,
+ "origin_server_ts": 1,
+ "depth": 12,
+ "content": {"body": "pulled_event"},
+ }
+ ),
+ room_version,
+ )
+
+ # We expect an outbound request to /backfill, so stub that out
+ self.mock_federation_transport_client.backfill.return_value = make_awaitable(
+ {
+ "origin": self.OTHER_SERVER_NAME,
+ "origin_server_ts": 123,
+ "pdus": [
+ # This is one of the important aspects of this test: we include
+ # `pulled_event_without_signatures` so it fails the signature check
+ # when we filter down the backfill response down to events which
+ # have valid signatures in
+ # `_check_sigs_and_hash_for_pulled_events_and_fetch`
+ pulled_event_without_signatures.get_pdu_json(),
+ # Then later when we process this valid signature event, when we
+ # fetch the missing `prev_event`s, we want to make sure that we
+ # backoff and don't try and fetch `pulled_event_without_signatures`
+ # again since we know it just had an invalid signature.
+ pulled_event.get_pdu_json(),
+ ],
+ }
+ )
+
+ # Keep track of the count and make sure we don't make any of these requests
+ event_endpoint_requested_count = 0
+ room_state_ids_endpoint_requested_count = 0
+ room_state_endpoint_requested_count = 0
+
+ async def get_event(
+ destination: str, event_id: str, timeout: Optional[int] = None
+ ) -> None:
+ nonlocal event_endpoint_requested_count
+ event_endpoint_requested_count += 1
+
+ async def get_room_state_ids(
+ destination: str, room_id: str, event_id: str
+ ) -> None:
+ nonlocal room_state_ids_endpoint_requested_count
+ room_state_ids_endpoint_requested_count += 1
+
+ async def get_room_state(
+ room_version: RoomVersion, destination: str, room_id: str, event_id: str
+ ) -> None:
+ nonlocal room_state_endpoint_requested_count
+ room_state_endpoint_requested_count += 1
+
+ # We don't expect an outbound request to `/event`, `/state_ids`, or `/state` in
+ # the happy path but if the logic is sneaking around what we expect, stub that
+ # out so we can detect that failure
+ self.mock_federation_transport_client.get_event.side_effect = get_event
+ self.mock_federation_transport_client.get_room_state_ids.side_effect = (
+ get_room_state_ids
+ )
+ self.mock_federation_transport_client.get_room_state.side_effect = (
+ get_room_state
+ )
+
+ # The function under test: try to backfill and process the pulled event
+ with LoggingContext("test"):
+ self.get_success(
+ self.hs.get_federation_event_handler().backfill(
+ self.OTHER_SERVER_NAME,
+ room_id,
+ limit=1,
+ extremities=["$some_extremity"],
+ )
+ )
+
+ if event_endpoint_requested_count > 0:
+ self.fail(
+ "We don't expect an outbound request to /event in the happy path but if "
+ "the logic is sneaking around what we expect, make sure to fail the test. "
+ "We don't expect it because the signature failure should cause us to backoff "
+ "and not asking about pulled_event_without_signatures="
+ f"{pulled_event_without_signatures.event_id} again"
+ )
+
+ if room_state_ids_endpoint_requested_count > 0:
+ self.fail(
+ "We don't expect an outbound request to /state_ids in the happy path but if "
+ "the logic is sneaking around what we expect, make sure to fail the test. "
+ "We don't expect it because the signature failure should cause us to backoff "
+ "and not asking about pulled_event_without_signatures="
+ f"{pulled_event_without_signatures.event_id} again"
+ )
+
+ if room_state_endpoint_requested_count > 0:
+ self.fail(
+ "We don't expect an outbound request to /state in the happy path but if "
+ "the logic is sneaking around what we expect, make sure to fail the test. "
+ "We don't expect it because the signature failure should cause us to backoff "
+ "and not asking about pulled_event_without_signatures="
+ f"{pulled_event_without_signatures.event_id} again"
+ )
+
+ # Make sure we only recorded a single failure which corresponds to the signature
+ # failure initially in `_check_sigs_and_hash_for_pulled_events_and_fetch` before
+ # we process all of the pulled events.
+ backfill_num_attempts_for_event_without_signatures = self.get_success(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event_without_signatures.event_id},
+ retcol="num_attempts",
+ )
+ )
+ self.assertEqual(backfill_num_attempts_for_event_without_signatures, 1)
+
+ # And make sure we didn't record a failure for the event that has the missing
+ # prev_event because we don't want to cause a cascade of failures. Not being
+ # able to fetch the `prev_events` just means we won't be able to de-outlier the
+ # pulled event. But we can still use an `outlier` in the state/auth chain for
+ # another event. So we shouldn't stop a downstream event from trying to pull it.
+ self.get_failure(
+ main_store.db_pool.simple_select_one_onecol(
+ table="event_failed_pull_attempts",
+ keyvalues={"event_id": pulled_event.event_id},
+ retcol="num_attempts",
+ ),
+ # StoreError: 404: No row found
+ StoreError,
+ )
+
+ def test_process_pulled_event_with_rejected_missing_state(self) -> None:
+ """Ensure that we correctly handle pulled events with missing state containing a
+ rejected state event
+
+ In this test, we pretend we are processing a "pulled" event (eg, via backfill
+ or get_missing_events). The pulled event has a prev_event we haven't previously
+ seen, so the server requests the state at that prev_event. We expect the server
+ to make a /state request.
+
+ We simulate a remote server whose /state includes a rejected kick event for a
+ local user. Notably, the kick event is rejected only because it cites a rejected
+ auth event and would otherwise be accepted based on the room state. During state
+ resolution, we re-run auth and can potentially introduce such rejected events
+ into the state if we are not careful.
+
+ We check that the pulled event is correctly persisted, and that the state
+ afterwards does not include the rejected kick.
+ """
+ # The DAG we are testing looks like:
+ #
+ # ...
+ # |
+ # v
+ # remote admin user joins
+ # | |
+ # +-------+ +-------+
+ # | |
+ # | rejected power levels
+ # | from remote server
+ # | |
+ # | v
+ # | rejected kick of local user
+ # v from remote server
+ # new power levels |
+ # | v
+ # | missing event
+ # | from remote server
+ # | |
+ # +-------+ +-------+
+ # | |
+ # v v
+ # pulled event
+ # from remote server
+ #
+ # (arrows are in the opposite direction to prev_events.)
+
+ OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+ main_store = self.hs.get_datastores().main
+
+ # Create the room.
+ kermit_user_id = self.register_user("kermit", "test")
+ kermit_tok = self.login("kermit", "test")
+ room_id = self.helper.create_room_as(
+ room_creator=kermit_user_id, tok=kermit_tok
+ )
+ room_version = self.get_success(main_store.get_room_version(room_id))
+
+ # Add another local user to the room. This user is going to be kicked in a
+ # rejected event.
+ bert_user_id = self.register_user("bert", "test")
+ bert_tok = self.login("bert", "test")
+ self.helper.join(room_id, user=bert_user_id, tok=bert_tok)
+
+ # Allow the remote user to kick bert.
+ # The remote user is going to send a rejected power levels event later on and we
+ # need state resolution to order it before another power levels event kermit is
+ # going to send later on. Hence we give both users the same power level, so that
+ # ties are broken by `origin_server_ts`.
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {"users": {kermit_user_id: 100, OTHER_USER: 100}},
+ tok=kermit_tok,
+ )
+
+ # Add the remote user to the room.
+ other_member_event = self.get_success(
+ event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
+ )
+
+ initial_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
+ create_event = self.get_success(
+ main_store.get_event(initial_state_map[("m.room.create", "")])
+ )
+ bert_member_event = self.get_success(
+ main_store.get_event(initial_state_map[("m.room.member", bert_user_id)])
+ )
+ power_levels_event = self.get_success(
+ main_store.get_event(initial_state_map[("m.room.power_levels", "")])
+ )
+
+ # We now need a rejected state event that will fail
+ # `check_state_independent_auth_rules` but pass
+ # `check_state_dependent_auth_rules`.
+
+ # First, we create a power levels event that we pretend the remote server has
+ # accepted, but the local homeserver will reject.
+ next_depth = 100
+ next_timestamp = other_member_event.origin_server_ts + 100
+ rejected_power_levels_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "m.room.power_levels",
+ "state_key": "",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [other_member_event.event_id],
+ "auth_events": [
+ initial_state_map[("m.room.create", "")],
+ initial_state_map[("m.room.power_levels", "")],
+ # The event will be rejected because of the duplicated auth
+ # event.
+ other_member_event.event_id,
+ other_member_event.event_id,
+ ],
+ "origin_server_ts": next_timestamp,
+ "depth": next_depth,
+ "content": power_levels_event.content,
+ }
+ ),
+ room_version,
+ )
+ next_depth += 1
+ next_timestamp += 100
+
+ with LoggingContext("send_rejected_power_levels_event"):
+ self.get_success(
+ self.hs.get_federation_event_handler()._process_pulled_event(
+ self.OTHER_SERVER_NAME,
+ rejected_power_levels_event,
+ backfilled=False,
+ )
+ )
+ self.assertEqual(
+ self.get_success(
+ main_store.get_rejection_reason(
+ rejected_power_levels_event.event_id
+ )
+ ),
+ "auth_error",
+ )
+
+ # Then we create a kick event for a local user that cites the rejected power
+ # levels event in its auth events. The kick event will be rejected solely
+ # because of the rejected auth event and would otherwise be accepted.
+ rejected_kick_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "m.room.member",
+ "state_key": bert_user_id,
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [rejected_power_levels_event.event_id],
+ "auth_events": [
+ initial_state_map[("m.room.create", "")],
+ rejected_power_levels_event.event_id,
+ initial_state_map[("m.room.member", bert_user_id)],
+ initial_state_map[("m.room.member", OTHER_USER)],
+ ],
+ "origin_server_ts": next_timestamp,
+ "depth": next_depth,
+ "content": {"membership": "leave"},
+ }
+ ),
+ room_version,
+ )
+ next_depth += 1
+ next_timestamp += 100
+
+ # The kick event must fail the state-independent auth rules, but pass the
+ # state-dependent auth rules, so that it has a chance of making it through state
+ # resolution.
+ self.get_failure(
+ check_state_independent_auth_rules(main_store, rejected_kick_event),
+ AuthError,
+ )
+ check_state_dependent_auth_rules(
+ rejected_kick_event,
+ [create_event, power_levels_event, other_member_event, bert_member_event],
+ )
+
+ # The kick event must also win over the original member event during state
+ # resolution.
+ self.assertEqual(
+ self.get_success(
+ _mainline_sort(
+ self.clock,
+ room_id,
+ event_ids=[
+ bert_member_event.event_id,
+ rejected_kick_event.event_id,
+ ],
+ resolved_power_event_id=power_levels_event.event_id,
+ event_map={
+ bert_member_event.event_id: bert_member_event,
+ rejected_kick_event.event_id: rejected_kick_event,
+ },
+ state_res_store=main_store,
+ )
+ ),
+ [bert_member_event.event_id, rejected_kick_event.event_id],
+ "The rejected kick event will not be applied after bert's join event "
+ "during state resolution. The test setup is incorrect.",
+ )
+
+ with LoggingContext("send_rejected_kick_event"):
+ self.get_success(
+ self.hs.get_federation_event_handler()._process_pulled_event(
+ self.OTHER_SERVER_NAME, rejected_kick_event, backfilled=False
+ )
+ )
+ self.assertEqual(
+ self.get_success(
+ main_store.get_rejection_reason(rejected_kick_event.event_id)
+ ),
+ "auth_error",
+ )
+
+ # We need another power levels event which will win over the rejected one during
+ # state resolution, otherwise we hit other issues where we end up with rejected
+ # a power levels event during state resolution.
+ self.reactor.advance(100) # ensure the `origin_server_ts` is larger
+ new_power_levels_event = self.get_success(
+ main_store.get_event(
+ self.helper.send_state(
+ room_id,
+ "m.room.power_levels",
+ {"users": {kermit_user_id: 100, OTHER_USER: 100, bert_user_id: 1}},
+ tok=kermit_tok,
+ )["event_id"]
+ )
+ )
+ self.assertEqual(
+ self.get_success(
+ _reverse_topological_power_sort(
+ self.clock,
+ room_id,
+ event_ids=[
+ new_power_levels_event.event_id,
+ rejected_power_levels_event.event_id,
+ ],
+ event_map={},
+ state_res_store=main_store,
+ full_conflicted_set=set(),
+ )
+ ),
+ [rejected_power_levels_event.event_id, new_power_levels_event.event_id],
+ "The power levels events will not have the desired ordering during state "
+ "resolution. The test setup is incorrect.",
+ )
+
+ # Create a missing event, so that the local homeserver has to do a `/state` or
+ # `/state_ids` request to pull state from the remote homeserver.
+ missing_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "m.room.message",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [rejected_kick_event.event_id],
+ "auth_events": [
+ initial_state_map[("m.room.create", "")],
+ initial_state_map[("m.room.power_levels", "")],
+ initial_state_map[("m.room.member", OTHER_USER)],
+ ],
+ "origin_server_ts": next_timestamp,
+ "depth": next_depth,
+ "content": {"msgtype": "m.text", "body": "foo"},
+ }
+ ),
+ room_version,
+ )
+ next_depth += 1
+ next_timestamp += 100
+
+ # The pulled event has two prev events, one of which is missing. We will make a
+ # `/state` or `/state_ids` request to the remote homeserver to ask it for the
+ # state before the missing prev event.
+ pulled_event = make_event_from_dict(
+ self.add_hashes_and_signatures_from_other_server(
+ {
+ "type": "m.room.message",
+ "room_id": room_id,
+ "sender": OTHER_USER,
+ "prev_events": [
+ new_power_levels_event.event_id,
+ missing_event.event_id,
+ ],
+ "auth_events": [
+ initial_state_map[("m.room.create", "")],
+ new_power_levels_event.event_id,
+ initial_state_map[("m.room.member", OTHER_USER)],
+ ],
+ "origin_server_ts": next_timestamp,
+ "depth": next_depth,
+ "content": {"msgtype": "m.text", "body": "bar"},
+ }
+ ),
+ room_version,
+ )
+ next_depth += 1
+ next_timestamp += 100
+
+ # Prepare the response for the `/state` or `/state_ids` request.
+ # The remote server believes bert has been kicked, while the local server does
+ # not.
+ state_before_missing_event = self.get_success(
+ main_store.get_events_as_list(initial_state_map.values())
+ )
+ state_before_missing_event = [
+ event
+ for event in state_before_missing_event
+ if event.event_id != bert_member_event.event_id
+ ]
+ state_before_missing_event.append(rejected_kick_event)
+
+ # We have to bump the clock a bit, to keep the retry logic in
+ # `FederationClient.get_pdu` happy
+ self.reactor.advance(60000)
+ with LoggingContext("send_pulled_event"):
+
+ async def get_event(
+ destination: str, event_id: str, timeout: Optional[int] = None
+ ) -> JsonDict:
+ self.assertEqual(destination, self.OTHER_SERVER_NAME)
+ self.assertEqual(event_id, missing_event.event_id)
+ return {"pdus": [missing_event.get_pdu_json()]}
+
+ async def get_room_state_ids(
+ destination: str, room_id: str, event_id: str
+ ) -> JsonDict:
+ self.assertEqual(destination, self.OTHER_SERVER_NAME)
+ self.assertEqual(event_id, missing_event.event_id)
+ return {
+ "pdu_ids": [event.event_id for event in state_before_missing_event],
+ "auth_chain_ids": [],
+ }
+
+ async def get_room_state(
+ room_version: RoomVersion, destination: str, room_id: str, event_id: str
+ ) -> StateRequestResponse:
+ self.assertEqual(destination, self.OTHER_SERVER_NAME)
+ self.assertEqual(event_id, missing_event.event_id)
+ return StateRequestResponse(
+ state=state_before_missing_event,
+ auth_events=[],
+ )
+
+ self.mock_federation_transport_client.get_event.side_effect = get_event
+ self.mock_federation_transport_client.get_room_state_ids.side_effect = (
+ get_room_state_ids
+ )
+ self.mock_federation_transport_client.get_room_state.side_effect = (
+ get_room_state
+ )
+
+ self.get_success(
+ self.hs.get_federation_event_handler()._process_pulled_event(
+ self.OTHER_SERVER_NAME, pulled_event, backfilled=False
+ )
+ )
+ self.assertIsNone(
+ self.get_success(
+ main_store.get_rejection_reason(pulled_event.event_id)
+ ),
+ "Pulled event was unexpectedly rejected, likely due to a problem with "
+ "the test setup.",
+ )
+ self.assertEqual(
+ {pulled_event.event_id},
+ self.get_success(
+ main_store.have_events_in_timeline([pulled_event.event_id])
+ ),
+ "Pulled event was not persisted, likely due to a problem with the test "
+ "setup.",
+ )
+
+ # We must not accept rejected events into the room state, so we expect bert
+ # to not be kicked, even if the remote server believes so.
+ new_state_map = self.get_success(
+ main_store.get_partial_current_state_ids(room_id)
+ )
+ self.assertEqual(
+ new_state_map[("m.room.member", bert_user_id)],
+ bert_member_event.event_id,
+ "Rejected kick event unexpectedly became part of room state.",
+ )
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 44da96c792..99384837d0 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -105,7 +105,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
event1, context = self._create_duplicate_event(txn_id)
ret_event1 = self.get_success(
- self.handler.handle_new_client_event(self.requester, event1, context)
+ self.handler.handle_new_client_event(
+ self.requester,
+ events_and_context=[(event1, context)],
+ )
)
stream_id1 = ret_event1.internal_metadata.stream_ordering
@@ -118,7 +121,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event2.event_id)
ret_event2 = self.get_success(
- self.handler.handle_new_client_event(self.requester, event2, context)
+ self.handler.handle_new_client_event(
+ self.requester,
+ events_and_context=[(event2, context)],
+ )
)
stream_id2 = ret_event2.internal_metadata.stream_ordering
@@ -314,4 +320,4 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"POST", path, content={}, access_token=self.access_token
)
- self.assertEqual(int(channel.result["code"]), 403)
+ self.assertEqual(channel.code, 403)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e6cd3af7b7..5955410524 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -11,9 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import os
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
from unittest.mock import ANY, Mock, patch
from urllib.parse import parse_qs, urlparse
@@ -22,12 +21,15 @@ import pymacaroons
from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.sso import MappingException
+from synapse.http.site import SynapseRequest
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import UserID
from synapse.util import Clock
-from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon
+from synapse.util.macaroons import get_value_from_macaroon
+from synapse.util.stringutils import random_string
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
from tests.unittest import HomeserverTestCase, override_config
try:
@@ -46,12 +48,6 @@ BASE_URL = "https://synapse/"
CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
SCOPES = ["openid"]
-AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
-TOKEN_ENDPOINT = ISSUER + "token"
-USERINFO_ENDPOINT = ISSUER + "userinfo"
-WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
-JWKS_URI = ISSUER + ".well-known/jwks.json"
-
# config for common cases
DEFAULT_CONFIG = {
"enabled": True,
@@ -66,9 +62,9 @@ DEFAULT_CONFIG = {
EXPLICIT_ENDPOINT_CONFIG = {
**DEFAULT_CONFIG,
"discover": False,
- "authorization_endpoint": AUTHORIZATION_ENDPOINT,
- "token_endpoint": TOKEN_ENDPOINT,
- "jwks_uri": JWKS_URI,
+ "authorization_endpoint": ISSUER + "authorize",
+ "token_endpoint": ISSUER + "token",
+ "jwks_uri": ISSUER + "jwks",
}
@@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-async def get_json(url: str) -> JsonDict:
- # Mock get_json calls to handle jwks & oidc discovery endpoints
- if url == WELL_KNOWN:
- # Minimal discovery document, as defined in OpenID.Discovery
- # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
- return {
- "issuer": ISSUER,
- "authorization_endpoint": AUTHORIZATION_ENDPOINT,
- "token_endpoint": TOKEN_ENDPOINT,
- "jwks_uri": JWKS_URI,
- "userinfo_endpoint": USERINFO_ENDPOINT,
- "response_types_supported": ["code"],
- "subject_types_supported": ["public"],
- "id_token_signing_alg_values_supported": ["RS256"],
- }
- elif url == JWKS_URI:
- return {"keys": []}
-
- return {}
-
-
def _key_file_path() -> str:
"""path to a file containing the private half of a test key"""
@@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
return config
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.http_client = Mock(spec=["get_json"])
- self.http_client.get_json.side_effect = get_json
- self.http_client.user_agent = b"Synapse Test"
+ self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER)
- hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+ hs = self.setup_test_homeserver()
+ self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
+ self.hs_patcher.start()
self.handler = hs.get_oidc_handler()
self.provider = self.handler._providers["oidc"]
@@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3
+ auth_handler = hs.get_auth_handler()
+ # Mock the complete SSO login method.
+ self.complete_sso_login = simple_async_mock()
+ auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment]
+
return hs
+ def tearDown(self) -> None:
+ self.hs_patcher.stop()
+ return super().tearDown()
+
+ def reset_mocks(self):
+ """Reset all the Mocks."""
+ self.fake_server.reset_mocks()
+ self.render_error.reset_mock()
+ self.complete_sso_login.reset_mock()
+
def metadata_edit(self, values):
"""Modify the result that will be returned by the well-known query"""
- async def patched_get_json(uri):
- res = await get_json(uri)
- if uri == WELL_KNOWN:
- res.update(values)
- return res
+ metadata = self.fake_server.get_metadata()
+ metadata.update(values)
+ return patch.object(self.fake_server, "get_metadata", return_value=metadata)
- return patch.object(self.http_client, "get_json", patched_get_json)
+ def start_authorization(
+ self,
+ userinfo: dict,
+ client_redirect_url: str = "http://client/redirect",
+ scope: str = "openid",
+ with_sid: bool = False,
+ ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]:
+ """Start an authorization request, and get the callback request back."""
+ nonce = random_string(10)
+ state = random_string(10)
+
+ code, grant = self.fake_server.start_authorization(
+ userinfo=userinfo,
+ scope=scope,
+ client_id=self.provider._client_auth.client_id,
+ redirect_uri=self.provider._callback_url,
+ nonce=nonce,
+ with_sid=with_sid,
+ )
+ session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+ return _build_callback_request(code, state, session), grant
def assertRenderedError(self, error, error_description=None):
self.render_error.assert_called_once()
@@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
metadata = self.get_success(self.provider.load_metadata())
- self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+ self.fake_server.get_metadata_handler.assert_called_once()
- self.assertEqual(metadata.issuer, ISSUER)
- self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
- self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
- self.assertEqual(metadata.jwks_uri, JWKS_URI)
- # FIXME: it seems like authlib does not have that defined in its metadata models
- # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+ self.assertEqual(metadata.issuer, self.fake_server.issuer)
+ self.assertEqual(
+ metadata.authorization_endpoint,
+ self.fake_server.authorization_endpoint,
+ )
+ self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint)
+ self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri)
+ # It seems like authlib does not have that defined in its metadata models
+ self.assertEqual(
+ metadata.get("userinfo_endpoint"),
+ self.fake_server.userinfo_endpoint,
+ )
# subsequent calls should be cached
- self.http_client.reset_mock()
+ self.reset_mocks()
self.get_success(self.provider.load_metadata())
- self.http_client.get_json.assert_not_called()
+ self.fake_server.get_metadata_handler.assert_not_called()
@override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
def test_no_discovery(self) -> None:
"""When discovery is disabled, it should not try to load from discovery document."""
self.get_success(self.provider.load_metadata())
- self.http_client.get_json.assert_not_called()
+ self.fake_server.get_metadata_handler.assert_not_called()
- @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
+ @override_config({"oidc_config": DEFAULT_CONFIG})
def test_load_jwks(self) -> None:
"""JWKS loading is done once (then cached) if used."""
jwks = self.get_success(self.provider.load_jwks())
- self.http_client.get_json.assert_called_once_with(JWKS_URI)
- self.assertEqual(jwks, {"keys": []})
+ self.fake_server.get_jwks_handler.assert_called_once()
+ self.assertEqual(jwks, self.fake_server.get_jwks())
# subsequent calls should be cached…
- self.http_client.reset_mock()
+ self.reset_mocks()
self.get_success(self.provider.load_jwks())
- self.http_client.get_json.assert_not_called()
+ self.fake_server.get_jwks_handler.assert_not_called()
# …unless forced
- self.http_client.reset_mock()
+ self.reset_mocks()
self.get_success(self.provider.load_jwks(force=True))
- self.http_client.get_json.assert_called_once_with(JWKS_URI)
+ self.fake_server.get_jwks_handler.assert_called_once()
- # Throw if the JWKS uri is missing
- original = self.provider.load_metadata
-
- async def patched_load_metadata():
- m = (await original()).copy()
- m.update({"jwks_uri": None})
- return m
-
- with patch.object(self.provider, "load_metadata", patched_load_metadata):
+ with self.metadata_edit({"jwks_uri": None}):
+ # If we don't do this, the load_metadata call will throw because of the
+ # missing jwks_uri
+ self.provider._user_profile_method = "userinfo_endpoint"
+ self.get_success(self.provider.load_metadata(force=True))
self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
@override_config({"oidc_config": DEFAULT_CONFIG})
@@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.provider.handle_redirect_request(req, b"http://client/redirect")
)
)
- auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+ auth_endpoint = urlparse(self.fake_server.authorization_endpoint)
self.assertEqual(url.scheme, auth_endpoint.scheme)
self.assertEqual(url.netloc, auth_endpoint.netloc)
@@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase):
with self.assertRaises(AttributeError):
_ = mapping_provider.get_extra_attributes
- token = {
- "type": "bearer",
- "id_token": "id_token",
- "access_token": "access_token",
- }
username = "bar"
userinfo = {
"sub": "foo",
"username": username,
}
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
- code = "code"
- state = "state"
- nonce = "nonce"
client_redirect_url = "http://client/redirect"
- ip_address = "10.0.0.1"
- session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
- request = _build_callback_request(code, state, session, ip_address=ip_address)
-
+ request, _ = self.start_authorization(
+ userinfo, client_redirect_url=client_redirect_url
+ )
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
expected_user_id,
- "oidc",
+ self.provider.idp_id,
request,
client_redirect_url,
None,
new_user=True,
auth_provider_session_id=None,
)
- self.provider._exchange_code.assert_called_once_with(code)
- self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.provider._fetch_userinfo.assert_not_called()
+ self.fake_server.post_token_handler.assert_called_once()
+ self.fake_server.get_userinfo_handler.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
+ request, _ = self.start_authorization(userinfo)
with patch.object(
self.provider,
"_remote_id_from_userinfo",
@@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error")
# Handle ID token errors
- self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment]
- self.get_success(self.handler.handle_oidc_callback(request))
+ request, _ = self.start_authorization(userinfo)
+ with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}):
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
- auth_handler.complete_sso_login.reset_mock()
- self.provider._exchange_code.reset_mock()
- self.provider._parse_id_token.reset_mock()
- self.provider._fetch_userinfo.reset_mock()
+ self.reset_mocks()
# With userinfo fetching
self.provider._user_profile_method = "userinfo_endpoint"
- token = {
- "type": "bearer",
- "access_token": "access_token",
- }
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
+ # Without the "openid" scope, the FakeProvider does not generate an id_token
+ request, _ = self.start_authorization(userinfo, scope="")
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
expected_user_id,
- "oidc",
+ self.provider.idp_id,
request,
- client_redirect_url,
+ ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- self.provider._exchange_code.assert_called_once_with(code)
- self.provider._parse_id_token.assert_not_called()
- self.provider._fetch_userinfo.assert_called_once_with(token)
+ self.fake_server.post_token_handler.assert_called_once()
+ self.fake_server.get_userinfo_handler.assert_called_once()
self.render_error.assert_not_called()
+ self.reset_mocks()
+
# With an ID token, userinfo fetching and sid in the ID token
self.provider._user_profile_method = "userinfo_endpoint"
- token = {
- "type": "bearer",
- "access_token": "access_token",
- "id_token": "id_token",
- }
- id_token = {
- "sid": "abcdefgh",
- }
- self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment]
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
- auth_handler.complete_sso_login.reset_mock()
- self.provider._fetch_userinfo.reset_mock()
+ request, grant = self.start_authorization(userinfo, with_sid=True)
+ self.assertIsNotNone(grant.sid)
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
expected_user_id,
- "oidc",
+ self.provider.idp_id,
request,
- client_redirect_url,
+ ANY,
None,
new_user=False,
- auth_provider_session_id=id_token["sid"],
+ auth_provider_session_id=grant.sid,
)
- self.provider._exchange_code.assert_called_once_with(code)
- self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.provider._fetch_userinfo.assert_called_once_with(token)
+ self.fake_server.post_token_handler.assert_called_once()
+ self.fake_server.get_userinfo_handler.assert_called_once()
self.render_error.assert_not_called()
# Handle userinfo fetching error
- self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment]
- self.get_success(self.handler.handle_oidc_callback(request))
+ request, _ = self.start_authorization(userinfo)
+ with self.fake_server.buggy_endpoint(userinfo=True):
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
- # Handle code exchange failure
- from synapse.handlers.oidc import OidcError
-
- self.provider._exchange_code = simple_async_mock( # type: ignore[assignment]
- raises=OidcError("invalid_request")
- )
- self.get_success(self.handler.handle_oidc_callback(request))
- self.assertRenderedError("invalid_request")
+ request, _ = self.start_authorization(userinfo)
+ with self.fake_server.buggy_endpoint(token=True):
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("server_error")
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_callback_session(self) -> None:
@@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_exchange_code(self) -> None:
"""Code exchange behaves correctly and handles various error scenarios."""
- token = {"type": "bearer"}
- token_json = json.dumps(token).encode("utf-8")
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
)
code = "code"
ret = self.get_success(self.provider._exchange_code(code))
- kwargs = self.http_client.request.call_args[1]
+ kwargs = self.fake_server.request.call_args[1]
self.assertEqual(ret, token)
self.assertEqual(kwargs["method"], "POST")
- self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+ self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
args = parse_qs(kwargs["data"].decode("utf-8"))
self.assertEqual(args["grant_type"], ["authorization_code"])
@@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
# Test error handling
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=400,
- phrase=b"Bad Request",
- body=b'{"error": "foo", "error_description": "bar"}',
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=400, payload={"error": "foo", "error_description": "bar"}
)
from synapse.handlers.oidc import OidcError
@@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(exc.value.error_description, "bar")
# Internal server error with no JSON body
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=500,
- phrase=b"Internal Server Error",
- body=b"Not JSON",
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse(
+ code=500, body=b"Not JSON"
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=500,
- phrase=b"Internal Server Error",
- body=b'{"error": "internal_server_error"}',
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=500, payload={"error": "internal_server_error"}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=400,
- phrase=b"Bad request",
- body=b"{}",
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=400, payload={}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=200,
- phrase=b"OK",
- body=b'{"error": "some_error"}',
- )
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ code=200, payload={"error": "some_error"}
)
exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "some_error")
@@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Test that code exchange works with a JWK client secret."""
from authlib.jose import jwt
- token = {"type": "bearer"}
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
- )
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
)
code = "code"
@@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(ret, token)
# the request should have hit the token endpoint
- kwargs = self.http_client.request.call_args[1]
+ kwargs = self.fake_server.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
- self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+ self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
# the client secret provided to the should be a jwt which can be checked with
# the public key
@@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_exchange_code_no_auth(self) -> None:
"""Test that code exchange works with no client secret."""
- token = {"type": "bearer"}
- self.http_client.request = simple_async_mock(
- return_value=FakeResponse(
- code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
- )
+ token = {
+ "type": "Bearer",
+ "access_token": "aabbcc",
+ }
+
+ self.fake_server.post_token_handler.side_effect = None
+ self.fake_server.post_token_handler.return_value = FakeResponse.json(
+ payload=token
)
code = "code"
ret = self.get_success(self.provider._exchange_code(code))
@@ -762,9 +730,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(ret, token)
# the request should have hit the token endpoint
- kwargs = self.http_client.request.call_args[1]
+ kwargs = self.fake_server.request.call_args[1]
self.assertEqual(kwargs["method"], "POST")
- self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+ self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
# check the POSTed data
args = parse_qs(kwargs["data"].decode("utf-8"))
@@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""
Login while using a mapping provider that implements get_extra_attributes.
"""
- token = {
- "type": "bearer",
- "id_token": "id_token",
- "access_token": "access_token",
- }
userinfo = {
"sub": "foo",
"username": "foo",
"phone": "1234567",
}
- self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment]
- self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
- state = "state"
- client_redirect_url = "http://client/redirect"
- session = self._generate_oidc_session_token(
- state=state,
- nonce="nonce",
- client_redirect_url=client_redirect_url,
- )
- request = _build_callback_request("code", state, session)
-
+ request, _ = self.start_authorization(userinfo)
self.get_success(self.handler.handle_oidc_callback(request))
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@foo:test",
- "oidc",
+ self.provider.idp_id,
request,
- client_redirect_url,
+ ANY,
{"phone": "1234567"},
new_user=True,
auth_provider_session_id=None,
@@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_user(self) -> None:
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
userinfo: dict = {
"sub": "test_user",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
"@test_user:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Some providers return an integer ID.
userinfo = {
"sub": 1234,
"username": "test_user_2",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
"@test_user_2:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Test if the mxid is already taken
store = self.hs.get_datastores().main
@@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error",
"Mapping provider does not support de-duplicating Matrix IDs",
@@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user.to_string(), password_hash=None)
)
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
# Map a user via SSO.
userinfo = {
"sub": "test",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
user.to_string(),
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Subsequent calls should map to the same mxid.
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
user.to_string(),
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
@@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
user.to_string(),
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
@@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2",
"username": "TEST_USER_2",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
args = self.assertRenderedError("mapping_error")
self.assertTrue(
args[2].startswith(
@@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None)
)
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_called_once_with(
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_called_once_with(
"@TEST_USER_2:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=False,
@@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
@override_config({"oidc_config": DEFAULT_CONFIG})
def test_map_userinfo_to_invalid_localpart(self) -> None:
"""If the mapping provider generates an invalid localpart it should be rejected."""
- self.get_success(
- _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
- )
+ userinfo = {"sub": "test2", "username": "föö"}
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config(
@@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_map_userinfo_to_user_retries(self) -> None:
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
store = self.hs.get_datastores().main
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
@@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
# test_user is already taken, so test_user1 gets registered instead.
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@test_user1:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
auth_provider_session_id=None,
)
- auth_handler.complete_sso_login.reset_mock()
+ self.reset_mocks()
# Register all of the potential mxids for a particular OIDC username.
self.get_success(
@@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "tester",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
self.assertRenderedError(
"mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
@@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error", "localpart is invalid: ")
@override_config(
@@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": None,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error", "localpart is invalid: ")
@override_config(
@@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_attribute_requirements(self) -> None:
"""The required attributes must be met from the OIDC userinfo response."""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
-
# userinfo lacking "test": "foobar" attribute should fail.
userinfo = {
"sub": "tester",
"username": "tester",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": "foobar" attribute should succeed.
userinfo = {
@@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": "foobar",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
# check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@tester:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
@@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_attribute_requirements_contains(self) -> None:
"""Test that auth succeeds if userinfo attribute CONTAINS required value"""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed.
userinfo = {
"sub": "tester",
"username": "tester",
"test": ["foobar", "foo", "bar"],
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
# check that the auth handler got called as expected
- auth_handler.complete_sso_login.assert_called_once_with(
+ self.complete_sso_login.assert_called_once_with(
"@tester:test",
- "oidc",
- ANY,
+ self.provider.idp_id,
+ request,
ANY,
None,
new_user=True,
@@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
Test that auth fails if attributes exist but don't match,
or are non-string values.
"""
- auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock()
# userinfo with "test": "not_foobar" attribute should fail
userinfo: dict = {
"sub": "tester",
"username": "tester",
"test": "not_foobar",
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": ["foo", "bar"] attribute should fail
userinfo = {
@@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": ["foo", "bar"],
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": False attribute should fail
# this is largely just to ensure we don't crash here
@@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": False,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": None attribute should fail
# a value of None breaks the OIDC spec, but it's important to not crash here
@@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": None,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": 1 attribute should fail
# this is largely just to ensure we don't crash here
@@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": 1,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
# userinfo with "test": 3.14 attribute should fail
# this is largely just to ensure we don't crash here
@@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "tester",
"test": 3.14,
}
- self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
- auth_handler.complete_sso_login.assert_not_called()
+ request, _ = self.start_authorization(userinfo)
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.complete_sso_login.assert_not_called()
def _generate_oidc_session_token(
self,
@@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return self.handler._macaroon_generator.generate_oidc_session_token(
state=state,
session_data=OidcSessionData(
- idp_id="oidc",
+ idp_id=self.provider.idp_id,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=ui_auth_session_id,
@@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
-async def _make_callback_with_userinfo(
- hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
-) -> None:
- """Mock up an OIDC callback with the given userinfo dict
-
- We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
- and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
-
- Args:
- hs: the HomeServer impl to send the callback to.
- userinfo: the OIDC userinfo dict
- client_redirect_url: the URL to redirect to on success.
- """
-
- handler = hs.get_oidc_handler()
- provider = handler._providers["oidc"]
- provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment]
- provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
- provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment]
-
- state = "state"
- session = handler._macaroon_generator.generate_oidc_session_token(
- state=state,
- session_data=OidcSessionData(
- idp_id="oidc",
- nonce="nonce",
- client_redirect_url=client_redirect_url,
- ui_auth_session_id="",
- ),
- )
- request = _build_callback_request("code", state, session)
-
- await handler.handle_oidc_callback(request)
-
-
def _build_callback_request(
code: str,
state: str,
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 4c62449c89..75934b1707 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -21,7 +21,6 @@ from unittest.mock import Mock
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
-from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
from synapse.rest.client import account, devices, login, logout, register
from synapse.types import JsonDict, UserID
@@ -167,16 +166,6 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
super().setUp()
- def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver()
- # Load the modules into the homeserver
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
- load_legacy_password_auth_providers(hs)
-
- return hs
-
@override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
def test_password_only_auth_progiver_login_legacy(self):
self.password_only_auth_provider_login_test_body()
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index c96dc6caf2..c5981ff965 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -15,6 +15,7 @@
from typing import Optional
from unittest.mock import Mock, call
+from parameterized import parameterized
from signedjson.key import generate_signing_key
from synapse.api.constants import EventTypes, Membership, PresenceState
@@ -37,6 +38,7 @@ from synapse.rest.client import room
from synapse.types import UserID, get_domain_from_id
from tests import unittest
+from tests.replication._base import BaseMultiWorkerStreamTestCase
class PresenceUpdateTestCase(unittest.HomeserverTestCase):
@@ -505,7 +507,7 @@ class PresenceTimeoutTestCase(unittest.TestCase):
self.assertEqual(state, new_state)
-class PresenceHandlerTestCase(unittest.HomeserverTestCase):
+class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
def prepare(self, reactor, clock, hs):
self.presence_handler = hs.get_presence_handler()
self.clock = hs.get_clock()
@@ -716,20 +718,47 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase):
# our status message should be the same as it was before
self.assertEqual(state.status_msg, status_msg)
- def test_set_presence_from_syncing_keeps_busy(self):
- """Test that presence set by syncing doesn't affect busy status"""
- # while this isn't the default
- self.presence_handler._busy_presence_enabled = True
+ @parameterized.expand([(False,), (True,)])
+ @unittest.override_config(
+ {
+ "experimental_features": {
+ "msc3026_enabled": True,
+ },
+ }
+ )
+ def test_set_presence_from_syncing_keeps_busy(self, test_with_workers: bool):
+ """Test that presence set by syncing doesn't affect busy status
+ Args:
+ test_with_workers: If True, check the presence state of the user by calling
+ /sync against a worker, rather than the main process.
+ """
user_id = "@test:server"
status_msg = "I'm busy!"
+ # By default, we call /sync against the main process.
+ worker_to_sync_against = self.hs
+ if test_with_workers:
+ # Create a worker and use it to handle /sync traffic instead.
+ # This is used to test that presence changes get replicated from workers
+ # to the main process correctly.
+ worker_to_sync_against = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "presence_writer"}
+ )
+
+ # Set presence to BUSY
self._set_presencestate_with_status_msg(user_id, PresenceState.BUSY, status_msg)
+ # Perform a sync with a presence state other than busy. This should NOT change
+ # our presence status; we only change from busy if we explicitly set it via
+ # /presence/*.
self.get_success(
- self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE)
+ worker_to_sync_against.get_presence_handler().user_syncing(
+ user_id, True, PresenceState.ONLINE
+ )
)
+ # Check against the main process that the user's presence did not change.
state = self.get_success(
self.presence_handler.get_state(UserID.from_string(user_id))
)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index f88c725a42..675aa023ac 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -14,6 +14,8 @@
from typing import Any, Awaitable, Callable, Dict
from unittest.mock import Mock
+from parameterized import parameterized
+
from twisted.test.proto_helpers import MemoryReactor
import synapse.types
@@ -327,6 +329,53 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
self.assertFalse(res)
+ @unittest.override_config(
+ {"server_name": "test:8888", "allowed_avatar_mimetypes": ["image/png"]}
+ )
+ def test_avatar_constraint_on_local_server_with_port(self):
+ """Test that avatar metadata is correctly fetched when the media is on a local
+ server and the server has an explicit port.
+
+ (This was previously a bug)
+ """
+ local_server_name = self.hs.config.server.server_name
+ media_id = "local"
+ local_mxc = f"mxc://{local_server_name}/{media_id}"
+
+ # mock up the existence of the avatar file
+ self._setup_local_files({media_id: {"mimetype": "image/png"}})
+
+ # and now check that check_avatar_size_and_mime_type is happy
+ self.assertTrue(
+ self.get_success(self.handler.check_avatar_size_and_mime_type(local_mxc))
+ )
+
+ @parameterized.expand([("remote",), ("remote:1234",)])
+ @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]})
+ def test_check_avatar_on_remote_server(self, remote_server_name: str) -> None:
+ """Test that avatar metadata is correctly fetched from a remote server"""
+ media_id = "remote"
+ remote_mxc = f"mxc://{remote_server_name}/{media_id}"
+
+ # if the media is remote, check_avatar_size_and_mime_type just checks the
+ # media cache, so we don't need to instantiate a real remote server. It is
+ # sufficient to poke an entry into the db.
+ self.get_success(
+ self.hs.get_datastores().main.store_cached_remote_media(
+ media_id=media_id,
+ media_type="image/png",
+ media_length=50,
+ origin=remote_server_name,
+ time_now_ms=self.clock.time_msec(),
+ upload_name=None,
+ filesystem_id="xyz",
+ )
+ )
+
+ self.assertTrue(
+ self.get_success(self.handler.check_avatar_size_and_mime_type(remote_mxc))
+ )
+
def _setup_local_files(self, names_and_props: Dict[str, Dict[str, Any]]):
"""Stores metadata about files in the database.
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index a95868b5c0..b55238650c 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -25,7 +25,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.event_source = hs.get_event_sources().sources.receipt
- def test_filters_out_private_receipt(self):
+ def test_filters_out_private_receipt(self) -> None:
self._test_filters_private(
[
{
@@ -45,7 +45,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
[],
)
- def test_filters_out_private_receipt_and_ignores_rest(self):
+ def test_filters_out_private_receipt_and_ignores_rest(self) -> None:
self._test_filters_private(
[
{
@@ -84,7 +84,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(self):
+ def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(
+ self,
+ ) -> None:
self._test_filters_private(
[
{
@@ -125,7 +127,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_handles_empty_event(self):
+ def test_handles_empty_event(self) -> None:
self._test_filters_private(
[
{
@@ -160,7 +162,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(self):
+ def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(
+ self,
+ ) -> None:
self._test_filters_private(
[
{
@@ -207,7 +211,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_handles_string_data(self):
+ def test_handles_string_data(self) -> None:
"""
Tests that an invalid shape for read-receipts is handled.
Context: https://github.com/matrix-org/synapse/issues/10603
@@ -242,7 +246,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_leaves_our_private_and_their_public(self):
+ def test_leaves_our_private_and_their_public(self) -> None:
self._test_filters_private(
[
{
@@ -296,7 +300,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
],
)
- def test_we_do_not_mutate(self):
+ def test_we_do_not_mutate(self) -> None:
"""Ensure the input values are not modified."""
events = [
{
@@ -320,7 +324,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
def _test_filters_private(
self, events: List[JsonDict], expected_output: List[JsonDict]
- ):
+ ) -> None:
"""Tests that the _filter_out_private returns the expected output"""
filtered_events = self.event_source.filter_out_private_receipts(
events, "@me:server.org"
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 23f35d5bf5..765df75d91 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -22,7 +22,6 @@ from synapse.api.errors import (
ResourceLimitError,
SynapseError,
)
-from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, RoomID, UserID, create_requester
@@ -144,12 +143,6 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
config=hs_config, federation_client=self.mock_federation_client
)
- load_legacy_spam_checkers(hs)
-
- module_api = hs.get_module_api()
- for module, config in hs.config.modules.loaded_modules:
- module(config=config, api=module_api)
-
return hs
def prepare(self, reactor, clock, hs):
@@ -504,7 +497,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- event_creation_handler.handle_new_client_event(requester, event, context)
+ event_creation_handler.handle_new_client_event(
+ requester, events_and_context=[(event, context)]
+ )
)
# Register a second user, which won't be be in the room (or even have an invite)
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index 254e7e4b80..6bbfd5dc84 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -1,4 +1,3 @@
-from http import HTTPStatus
from unittest.mock import Mock, patch
from twisted.test.proto_helpers import MemoryReactor
@@ -7,7 +6,7 @@ import synapse.rest.admin
import synapse.rest.client.login
import synapse.rest.client.room
from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import LimitExceededError
+from synapse.api.errors import LimitExceededError, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events import FrozenEventV3
from synapse.federation.federation_client import SendJoinResult
@@ -15,10 +14,14 @@ from synapse.server import HomeServer
from synapse.types import UserID, create_requester
from synapse.util import Clock
-from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
from tests.test_utils import make_awaitable
-from tests.unittest import FederatingHomeserverTestCase, override_config
+from tests.unittest import (
+ FederatingHomeserverTestCase,
+ HomeserverTestCase,
+ override_config,
+)
class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
@@ -217,7 +220,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
# - trying to remote-join again.
-class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase):
+class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCase):
servlets = [
synapse.rest.admin.register_servlets,
synapse.rest.client.login.register_servlets,
@@ -260,7 +263,7 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestC
f"/_matrix/client/v3/rooms/{self.room_id}/join",
access_token=self.bob_token,
)
- self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
+ self.assertEqual(channel.code, 200, channel.json_body)
# wait for join to arrive over replication
self.replicate()
@@ -288,3 +291,88 @@ class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestC
),
LimitExceededError,
)
+
+
+class RoomMemberMasterHandlerTestCase(HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ synapse.rest.client.login.register_servlets,
+ synapse.rest.client.room.register_servlets,
+ ]
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.handler = hs.get_room_member_handler()
+ self.store = hs.get_datastores().main
+
+ # Create two users.
+ self.alice = self.register_user("alice", "pass")
+ self.alice_ID = UserID.from_string(self.alice)
+ self.alice_token = self.login("alice", "pass")
+ self.bob = self.register_user("bob", "pass")
+ self.bob_ID = UserID.from_string(self.bob)
+ self.bob_token = self.login("bob", "pass")
+
+ # Create a room on this homeserver.
+ self.room_id = self.helper.create_room_as(self.alice, tok=self.alice_token)
+
+ def test_leave_and_forget(self) -> None:
+ """Tests that forget a room is successfully. The test is performed with two users,
+ as forgetting by the last user respectively after all users had left the
+ is a special edge case."""
+ self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
+
+ # alice is not the last room member that leaves and forgets the room
+ self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+ self.get_success(self.handler.forget(self.alice_ID, self.room_id))
+ self.assertTrue(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
+
+ # the server has not forgotten the room
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room_id))
+ )
+
+ def test_leave_and_forget_last_user(self) -> None:
+ """Tests that forget a room is successfully when the last user has left the room."""
+
+ # alice is the last room member that leaves and forgets the room
+ self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+ self.get_success(self.handler.forget(self.alice_ID, self.room_id))
+ self.assertTrue(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
+
+ # the server has forgotten the room
+ self.assertTrue(
+ self.get_success(self.store.is_locally_forgotten_room(self.room_id))
+ )
+
+ def test_forget_when_not_left(self) -> None:
+ """Tests that a user cannot not forgets a room that has not left."""
+ self.get_failure(self.handler.forget(self.alice_ID, self.room_id), SynapseError)
+
+ def test_rejoin_forgotten_by_user(self) -> None:
+ """Test that a user that has forgotten a room can do a re-join.
+ The room was not forgotten from the local server.
+ One local user is still member of the room."""
+ self.helper.join(self.room_id, user=self.bob, tok=self.bob_token)
+
+ self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token)
+ self.get_success(self.handler.forget(self.alice_ID, self.room_id))
+ self.assertTrue(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
+
+ # the server has not forgotten the room
+ self.assertFalse(
+ self.get_success(self.store.is_locally_forgotten_room(self.room_id))
+ )
+
+ self.helper.join(self.room_id, user=self.alice, tok=self.alice_token)
+ # TODO: A join to a room does not invalidate the forgotten cache
+ # see https://github.com/matrix-org/synapse/issues/13262
+ self.store.did_forget.invalidate_all()
+ self.assertFalse(
+ self.get_success(self.store.did_forget(self.alice, self.room_id))
+ )
diff --git a/tests/handlers/test_sso.py b/tests/handlers/test_sso.py
new file mode 100644
index 0000000000..137deab138
--- /dev/null
+++ b/tests/handlers/test_sso.py
@@ -0,0 +1,145 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from http import HTTPStatus
+from typing import BinaryIO, Callable, Dict, List, Optional, Tuple
+from unittest.mock import Mock
+
+from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.http_headers import Headers
+
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.client import RawHeaders
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests import unittest
+from tests.test_utils import SMALL_PNG, FakeResponse
+
+
+class TestSSOHandler(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.http_client = Mock(spec=["get_file"])
+ self.http_client.get_file.side_effect = mock_get_file
+ self.http_client.user_agent = b"Synapse Test"
+ hs = self.setup_test_homeserver(
+ proxied_blacklisted_http_client=self.http_client
+ )
+ return hs
+
+ async def test_set_avatar(self) -> None:
+ """Tests successfully setting the avatar of a newly created user"""
+ handler = self.hs.get_sso_handler()
+
+ # Create a new user to set avatar for
+ reg_handler = self.hs.get_registration_handler()
+ user_id = self.get_success(reg_handler.register_user(approved=True))
+
+ self.assertTrue(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ # Ensure avatar is set on this newly created user,
+ # so no need to compare for the exact image
+ profile_handler = self.hs.get_profile_handler()
+ profile = self.get_success(profile_handler.get_profile(user_id))
+ self.assertIsNot(profile["avatar_url"], None)
+
+ @unittest.override_config({"max_avatar_size": 1})
+ async def test_set_avatar_too_big_image(self) -> None:
+ """Tests that saving an avatar fails when it is too big"""
+ handler = self.hs.get_sso_handler()
+
+ # any random user works since image check is supposed to fail
+ user_id = "@sso-user:test"
+
+ self.assertFalse(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ @unittest.override_config({"allowed_avatar_mimetypes": ["image/jpeg"]})
+ async def test_set_avatar_incorrect_mime_type(self) -> None:
+ """Tests that saving an avatar fails when its mime type is not allowed"""
+ handler = self.hs.get_sso_handler()
+
+ # any random user works since image check is supposed to fail
+ user_id = "@sso-user:test"
+
+ self.assertFalse(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ async def test_skip_saving_avatar_when_not_changed(self) -> None:
+ """Tests whether saving of avatar correctly skips if the avatar hasn't
+ changed"""
+ handler = self.hs.get_sso_handler()
+
+ # Create a new user to set avatar for
+ reg_handler = self.hs.get_registration_handler()
+ user_id = self.get_success(reg_handler.register_user(approved=True))
+
+ # set avatar for the first time, should be a success
+ self.assertTrue(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ # get avatar picture for comparison after another attempt
+ profile_handler = self.hs.get_profile_handler()
+ profile = self.get_success(profile_handler.get_profile(user_id))
+ url_to_match = profile["avatar_url"]
+
+ # set same avatar for the second time, should be a success
+ self.assertTrue(
+ self.get_success(handler.set_avatar(user_id, "http://my.server/me.png"))
+ )
+
+ # compare avatar picture's url from previous step
+ profile = self.get_success(profile_handler.get_profile(user_id))
+ self.assertEqual(profile["avatar_url"], url_to_match)
+
+
+async def mock_get_file(
+ url: str,
+ output_stream: BinaryIO,
+ max_size: Optional[int] = None,
+ headers: Optional[RawHeaders] = None,
+ is_allowed_content_type: Optional[Callable[[str], bool]] = None,
+) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
+
+ fake_response = FakeResponse(code=404)
+ if url == "http://my.server/me.png":
+ fake_response = FakeResponse(
+ code=200,
+ headers=Headers(
+ {"Content-Type": ["image/png"], "Content-Length": [str(len(SMALL_PNG))]}
+ ),
+ body=SMALL_PNG,
+ )
+
+ if max_size is not None and max_size < len(SMALL_PNG):
+ raise SynapseError(
+ HTTPStatus.BAD_GATEWAY,
+ "Requested file is too large > %r bytes" % (max_size,),
+ Codes.TOO_LARGE,
+ )
+
+ if is_allowed_content_type and not is_allowed_content_type("image/png"):
+ raise SynapseError(
+ HTTPStatus.BAD_GATEWAY,
+ (
+ "Requested file's content type not allowed for this operation: %s"
+ % "image/png"
+ ),
+ )
+
+ output_stream.write(fake_response.body)
+
+ return len(SMALL_PNG), {b"Content-Type": [b"image/png"]}, "", 200
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e3f38fbcc5..ab5c101eb7 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -159,6 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Blow away caches (supported room versions can only change due to a restart).
self.store.get_rooms_for_user_with_stream_ordering.invalidate_all()
+ self.store.get_rooms_for_user.invalidate_all()
self.get_success(self.store._get_event_cache.clear())
self.store._event_ref.clear()
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 7af1333126..9c821b3042 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -25,7 +25,7 @@ from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer
from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID, create_requester
+from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock
from tests import unittest
@@ -117,8 +117,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = []
- async def check_user_in_room(room_id: str, user_id: str) -> None:
- if user_id not in [u.to_string() for u in self.room_members]:
+ async def check_user_in_room(room_id: str, requester: Requester) -> None:
+ if requester.user.to_string() not in [
+ u.to_string() for u in self.room_members
+ ]:
raise AuthError(401, "User is not in the room")
return None
@@ -127,7 +129,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
async def check_host_in_room(room_id: str, server_name: str) -> bool:
return room_id == ROOM_ID
- hs.get_event_auth_handler().check_host_in_room = check_host_in_room
+ hs.get_event_auth_handler().is_host_in_room = check_host_in_room
async def get_current_hosts_in_room(room_id: str):
return {member.domain for member in self.room_members}
@@ -136,6 +138,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
get_current_hosts_in_room
)
+ hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (
+ get_current_hosts_in_room
+ )
+
async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members}
|