diff --git a/tests/handlers/test_admin.py b/tests/handlers/test_admin.py
index fc37c4328c..5c2b4de1a6 100644
--- a/tests/handlers/test_admin.py
+++ b/tests/handlers/test_admin.py
@@ -35,7 +35,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.admin_handler = hs.get_handlers().admin_handler
+ self.admin_handler = hs.get_admin_handler()
self.user1 = self.register_user("user1", "password")
self.token1 = self.login("user1", "password")
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 2a0b7c1b56..53763cd0f9 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -18,6 +18,7 @@ from mock import Mock
from twisted.internet import defer
from synapse.handlers.appservice import ApplicationServicesHandler
+from synapse.types import RoomStreamToken
from tests.test_utils import make_awaitable
from tests.utils import MockClock
@@ -41,7 +42,6 @@ class AppServiceHandlerTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
self.handler = ApplicationServicesHandler(hs)
- @defer.inlineCallbacks
def test_notify_interested_services(self):
interested_service = self._mkservice(is_interested=True)
services = [
@@ -61,12 +61,12 @@ class AppServiceHandlerTestCase(unittest.TestCase):
defer.succeed((0, [event])),
defer.succeed((0, [])),
]
- yield defer.ensureDeferred(self.handler.notify_interested_services(0))
+ self.handler.notify_interested_services(RoomStreamToken(None, 0))
+
self.mock_scheduler.submit_event_for_as.assert_called_once_with(
interested_service, event
)
- @defer.inlineCallbacks
def test_query_user_exists_unknown_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
@@ -80,10 +80,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
defer.succeed((0, [event])),
defer.succeed((0, [])),
]
- yield defer.ensureDeferred(self.handler.notify_interested_services(0))
+
+ self.handler.notify_interested_services(RoomStreamToken(None, 0))
+
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
- @defer.inlineCallbacks
def test_query_user_exists_known_user(self):
user_id = "@someone:anywhere"
services = [self._mkservice(is_interested=True)]
@@ -97,7 +98,9 @@ class AppServiceHandlerTestCase(unittest.TestCase):
defer.succeed((0, [event])),
defer.succeed((0, [])),
]
- yield defer.ensureDeferred(self.handler.notify_interested_services(0))
+
+ self.handler.notify_interested_services(RoomStreamToken(None, 0))
+
self.assertFalse(
self.mock_as_api.query_user.called,
"query_user called when it shouldn't have been.",
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 97877c2e42..e24ce81284 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -21,24 +21,17 @@ from twisted.internet import defer
import synapse
import synapse.api.errors
from synapse.api.errors import ResourceLimitError
-from synapse.handlers.auth import AuthHandler
from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
-class AuthHandlers:
- def __init__(self, hs):
- self.auth_handler = AuthHandler(hs)
-
-
class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
- self.hs = yield setup_test_homeserver(self.addCleanup, handlers=None)
- self.hs.handlers = AuthHandlers(self.hs)
- self.auth_handler = self.hs.handlers.auth_handler
+ self.hs = yield setup_test_homeserver(self.addCleanup)
+ self.auth_handler = self.hs.get_auth_handler()
self.macaroon_generator = self.hs.get_macaroon_generator()
# MAU tests
@@ -59,7 +52,7 @@ class AuthTestCase(unittest.TestCase):
self.fail("some_user was not in %s" % macaroon.inspect())
def test_macaroon_caveats(self):
- self.hs.clock.now = 5000
+ self.hs.get_clock().now = 5000
token = self.macaroon_generator.generate_access_token("a_user")
macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -85,7 +78,7 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_short_term_login_token_gives_user_id(self):
- self.hs.clock.now = 1000
+ self.hs.get_clock().now = 1000
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
user_id = yield defer.ensureDeferred(
@@ -94,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual("a_user", user_id)
# when we advance the clock, the token should be rejected
- self.hs.clock.now = 6000
+ self.hs.get_clock().now = 6000
with self.assertRaises(synapse.api.errors.AuthError):
yield defer.ensureDeferred(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 969d44c787..875aaec2c6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
+# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -224,3 +225,84 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
)
self.reactor.advance(1000)
+
+
+class DehydrationTestCase(unittest.HomeserverTestCase):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver("server", http_client=None)
+ self.handler = hs.get_device_handler()
+ self.registration = hs.get_registration_handler()
+ self.auth = hs.get_auth()
+ self.store = hs.get_datastore()
+ return hs
+
+ def test_dehydrate_and_rehydrate_device(self):
+ user_id = "@boris:dehydration"
+
+ self.get_success(self.store.register_user(user_id, "foobar"))
+
+ # First check if we can store and fetch a dehydrated device
+ stored_dehydrated_device_id = self.get_success(
+ self.handler.store_dehydrated_device(
+ user_id=user_id,
+ device_data={"device_data": {"foo": "bar"}},
+ initial_device_display_name="dehydrated device",
+ )
+ )
+
+ retrieved_device_id, device_data = self.get_success(
+ self.handler.get_dehydrated_device(user_id=user_id)
+ )
+
+ self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
+ self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
+
+ # Create a new login for the user and dehydrated the device
+ device_id, access_token = self.get_success(
+ self.registration.register_device(
+ user_id=user_id, device_id=None, initial_display_name="new device",
+ )
+ )
+
+ # Trying to claim a nonexistent device should throw an error
+ self.get_failure(
+ self.handler.rehydrate_device(
+ user_id=user_id,
+ access_token=access_token,
+ device_id="not the right device ID",
+ ),
+ synapse.api.errors.NotFoundError,
+ )
+
+ # dehydrating the right devices should succeed and change our device ID
+ # to the dehydrated device's ID
+ res = self.get_success(
+ self.handler.rehydrate_device(
+ user_id=user_id,
+ access_token=access_token,
+ device_id=retrieved_device_id,
+ )
+ )
+
+ self.assertEqual(res, {"success": True})
+
+ # make sure that our device ID has changed
+ user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
+
+ self.assertEqual(user_info.device_id, retrieved_device_id)
+
+ # make sure the device has the display name that was set from the login
+ res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
+
+ self.assertEqual(res["display_name"], "new device")
+
+ # make sure that the device ID that we were initially assigned no longer exists
+ self.get_failure(
+ self.handler.get_device(user_id, device_id),
+ synapse.api.errors.NotFoundError,
+ )
+
+ # make sure that there's no device available for dehydrating now
+ ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
+
+ self.assertIsNone(ret)
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index bc0c5aefdc..ee6ef5e6fa 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -48,7 +48,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
federation_registry=self.mock_registry,
)
- self.handler = hs.get_handlers().directory_handler
+ self.handler = hs.get_directory_handler()
self.store = hs.get_datastore()
@@ -110,7 +110,7 @@ class TestCreateAlias(unittest.HomeserverTestCase):
]
def prepare(self, reactor, clock, hs):
- self.handler = hs.get_handlers().directory_handler
+ self.handler = hs.get_directory_handler()
# Create user
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -173,7 +173,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.handler = hs.get_handlers().directory_handler
+ self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
# Create user
@@ -289,7 +289,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
- self.handler = hs.get_handlers().directory_handler
+ self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
# Create user
@@ -412,7 +412,6 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
b"directory/room/%23test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
)
- self.render(request)
self.assertEquals(403, channel.code, channel.result)
def test_allowed(self):
@@ -423,7 +422,6 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
b"directory/room/%23unofficial_test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
@@ -438,11 +436,10 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.room_list_handler = hs.get_room_list_handler()
- self.directory_handler = hs.get_handlers().directory_handler
+ self.directory_handler = hs.get_directory_handler()
return hs
@@ -452,7 +449,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
# Room list is enabled so we should get some results
request, channel = self.make_request("GET", b"publicRooms")
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) > 0)
@@ -461,7 +457,6 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
# Room list disabled so we should get no results
request, channel = self.make_request("GET", b"publicRooms")
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) == 0)
@@ -470,5 +465,4 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
request, channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
- self.render(request)
self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 366dcfb670..924f29f051 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -33,13 +33,15 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
super().__init__(*args, **kwargs)
self.hs = None # type: synapse.server.HomeServer
self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler
+ self.store = None # type: synapse.storage.Storage
@defer.inlineCallbacks
def setUp(self):
self.hs = yield utils.setup_test_homeserver(
- self.addCleanup, handlers=None, federation_client=mock.Mock()
+ self.addCleanup, federation_client=mock.Mock()
)
self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
+ self.store = self.hs.get_datastore()
@defer.inlineCallbacks
def test_query_local_devices_no_devices(self):
@@ -172,6 +174,89 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
@defer.inlineCallbacks
+ def test_fallback_key(self):
+ local_user = "@boris:" + self.hs.hostname
+ device_id = "xyz"
+ fallback_key = {"alg1:k1": "key1"}
+ otk = {"alg1:k2": "key2"}
+
+ # we shouldn't have any unused fallback keys yet
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, [])
+
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user,
+ device_id,
+ {"org.matrix.msc2732.fallback_keys": fallback_key},
+ )
+ )
+
+ # we should now have an unused alg1 key
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, ["alg1"])
+
+ # claiming an OTK when no OTKs are available should return the fallback
+ # key
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ )
+
+ # we shouldn't have any unused fallback keys again
+ res = yield defer.ensureDeferred(
+ self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
+ )
+ self.assertEqual(res, [])
+
+ # claiming an OTK again should return the same fallback key
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ )
+
+ # if the user uploads a one-time key, the next claim should fetch the
+ # one-time key, and then go back to the fallback
+ yield defer.ensureDeferred(
+ self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": otk}
+ )
+ )
+
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
+ )
+
+ res = yield defer.ensureDeferred(
+ self.handler.claim_one_time_keys(
+ {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
+ )
+ )
+ self.assertEqual(
+ res,
+ {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
+ )
+
+ @defer.inlineCallbacks
def test_replace_master_key(self):
"""uploading a new signing key should make the old signing key unavailable"""
local_user = "@boris:" + self.hs.hostname
diff --git a/tests/handlers/test_e2e_room_keys.py b/tests/handlers/test_e2e_room_keys.py
index 7adde9b9de..45f201a399 100644
--- a/tests/handlers/test_e2e_room_keys.py
+++ b/tests/handlers/test_e2e_room_keys.py
@@ -54,7 +54,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
self.hs = yield utils.setup_test_homeserver(
- self.addCleanup, handlers=None, replication_layer=mock.Mock()
+ self.addCleanup, replication_layer=mock.Mock()
)
self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
self.local_user = "@boris:" + self.hs.hostname
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 96fea58673..bf866dacf3 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -38,7 +38,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(http_client=None)
- self.handler = hs.get_handlers().federation_handler
+ self.handler = hs.get_federation_handler()
self.store = hs.get_datastore()
return hs
@@ -59,7 +59,6 @@ class FederationTestCase(unittest.HomeserverTestCase):
)
d = self.handler.on_exchange_third_party_invite_request(
- room_id=room_id,
event_dict={
"type": EventTypes.Member,
"room_id": room_id,
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
new file mode 100644
index 0000000000..af42775815
--- /dev/null
+++ b/tests/handlers/test_message.py
@@ -0,0 +1,212 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Tuple
+
+from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+from synapse.types import create_requester
+from synapse.util.stringutils import random_string
+
+from tests import unittest
+
+logger = logging.getLogger(__name__)
+
+
+class EventCreationTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.handler = self.hs.get_event_creation_handler()
+ self.persist_event_storage = self.hs.get_storage().persistence
+
+ self.user_id = self.register_user("tester", "foobar")
+ self.access_token = self.login("tester", "foobar")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
+
+ self.info = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(self.access_token,)
+ )
+ self.token_id = self.info.token_id
+
+ self.requester = create_requester(self.user_id, access_token_id=self.token_id)
+
+ def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
+ """Create a new event with the given transaction ID. All events produced
+ by this method will be considered duplicates.
+ """
+
+ # We create a new event with a random body, as otherwise we'll produce
+ # *exactly* the same event with the same hash, and so same event ID.
+ return self.get_success(
+ self.handler.create_event(
+ self.requester,
+ {
+ "type": EventTypes.Message,
+ "room_id": self.room_id,
+ "sender": self.requester.user.to_string(),
+ "content": {"msgtype": "m.text", "body": random_string(5)},
+ },
+ txn_id=txn_id,
+ )
+ )
+
+ def test_duplicated_txn_id(self):
+ """Test that attempting to handle/persist an event with a transaction ID
+ that has already been persisted correctly returns the old event and does
+ *not* produce duplicate messages.
+ """
+
+ txn_id = "something_suitably_random"
+
+ event1, context = self._create_duplicate_event(txn_id)
+
+ ret_event1 = self.get_success(
+ self.handler.handle_new_client_event(self.requester, event1, context)
+ )
+ stream_id1 = ret_event1.internal_metadata.stream_ordering
+
+ self.assertEqual(event1.event_id, ret_event1.event_id)
+
+ event2, context = self._create_duplicate_event(txn_id)
+
+ # We want to test that the deduplication at the persit event end works,
+ # so we want to make sure we test with different events.
+ self.assertNotEqual(event1.event_id, event2.event_id)
+
+ ret_event2 = self.get_success(
+ self.handler.handle_new_client_event(self.requester, event2, context)
+ )
+ stream_id2 = ret_event2.internal_metadata.stream_ordering
+
+ # Assert that the returned values match those from the initial event
+ # rather than the new one.
+ self.assertEqual(ret_event1.event_id, ret_event2.event_id)
+ self.assertEqual(stream_id1, stream_id2)
+
+ # Let's test that calling `persist_event` directly also does the right
+ # thing.
+ event3, context = self._create_duplicate_event(txn_id)
+ self.assertNotEqual(event1.event_id, event3.event_id)
+
+ ret_event3, event_pos3, _ = self.get_success(
+ self.persist_event_storage.persist_event(event3, context)
+ )
+
+ # Assert that the returned values match those from the initial event
+ # rather than the new one.
+ self.assertEqual(ret_event1.event_id, ret_event3.event_id)
+ self.assertEqual(stream_id1, event_pos3.stream)
+
+ # Let's test that calling `persist_events` directly also does the right
+ # thing.
+ event4, context = self._create_duplicate_event(txn_id)
+ self.assertNotEqual(event1.event_id, event3.event_id)
+
+ events, _ = self.get_success(
+ self.persist_event_storage.persist_events([(event3, context)])
+ )
+ ret_event4 = events[0]
+
+ # Assert that the returned values match those from the initial event
+ # rather than the new one.
+ self.assertEqual(ret_event1.event_id, ret_event4.event_id)
+
+ def test_duplicated_txn_id_one_call(self):
+ """Test that we correctly handle duplicates that we try and persist at
+ the same time.
+ """
+
+ txn_id = "something_else_suitably_random"
+
+ # Create two duplicate events to persist at the same time
+ event1, context1 = self._create_duplicate_event(txn_id)
+ event2, context2 = self._create_duplicate_event(txn_id)
+
+ # Ensure their event IDs are different to start with
+ self.assertNotEqual(event1.event_id, event2.event_id)
+
+ events, _ = self.get_success(
+ self.persist_event_storage.persist_events(
+ [(event1, context1), (event2, context2)]
+ )
+ )
+
+ # Check that we've deduplicated the events.
+ self.assertEqual(len(events), 2)
+ self.assertEqual(events[0].event_id, events[1].event_id)
+
+
+class ServerAclValidationTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ self.user_id = self.register_user("tester", "foobar")
+ self.access_token = self.login("tester", "foobar")
+ self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
+
+ def test_allow_server_acl(self):
+ """Test that sending an ACL that blocks everyone but ourselves works.
+ """
+
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.ServerACL,
+ body={"allow": [self.hs.hostname]},
+ tok=self.access_token,
+ expect_code=200,
+ )
+
+ def test_deny_server_acl_block_outselves(self):
+ """Test that sending an ACL that blocks ourselves does not work.
+ """
+ self.helper.send_state(
+ self.room_id,
+ EventTypes.ServerACL,
+ body={},
+ tok=self.access_token,
+ expect_code=400,
+ )
+
+ def test_deny_redact_server_acl(self):
+ """Test that attempting to redact an ACL is blocked.
+ """
+
+ body = self.helper.send_state(
+ self.room_id,
+ EventTypes.ServerACL,
+ body={"allow": [self.hs.hostname]},
+ tok=self.access_token,
+ expect_code=200,
+ )
+ event_id = body["event_id"]
+
+ # Redaction of event should fail.
+ path = "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room_id, event_id)
+ request, channel = self.make_request(
+ "POST", path, content={}, access_token=self.access_token
+ )
+ self.assertEqual(int(channel.result["code"]), 403)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index d5087e58be..a308c46da9 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import json
from urllib.parse import parse_qs, urlparse
@@ -24,12 +23,8 @@ import pymacaroons
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
-from synapse.handlers.oidc_handler import (
- MappingException,
- OidcError,
- OidcHandler,
- OidcMappingProvider,
-)
+from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
+from synapse.handlers.sso import MappingException
from synapse.types import UserID
from tests.unittest import HomeserverTestCase, override_config
@@ -94,6 +89,14 @@ class TestMappingProviderExtra(TestMappingProvider):
return {"phone": userinfo["phone"]}
+class TestMappingProviderFailures(TestMappingProvider):
+ async def map_user_attributes(self, userinfo, token, failures):
+ return {
+ "localpart": userinfo["username"] + (str(failures) if failures else ""),
+ "display_name": None,
+ }
+
+
def simple_async_mock(return_value=None, raises=None):
# AsyncMock is not available in python3.5, this mimics part of its behaviour
async def cb(*args, **kwargs):
@@ -124,22 +127,16 @@ async def get_json(url):
class OidcHandlerTestCase(HomeserverTestCase):
- def make_homeserver(self, reactor, clock):
-
- self.http_client = Mock(spec=["get_json"])
- self.http_client.get_json.side_effect = get_json
- self.http_client.user_agent = "Synapse Test"
-
- config = self.default_config()
+ def default_config(self):
+ config = super().default_config()
config["public_baseurl"] = BASE_URL
- oidc_config = {}
- oidc_config["enabled"] = True
- oidc_config["client_id"] = CLIENT_ID
- oidc_config["client_secret"] = CLIENT_SECRET
- oidc_config["issuer"] = ISSUER
- oidc_config["scopes"] = SCOPES
- oidc_config["user_mapping_provider"] = {
- "module": __name__ + ".TestMappingProvider",
+ oidc_config = {
+ "enabled": True,
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ "issuer": ISSUER,
+ "scopes": SCOPES,
+ "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
}
# Update this config with what's in the default config so that
@@ -147,13 +144,24 @@ class OidcHandlerTestCase(HomeserverTestCase):
oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
- hs = self.setup_test_homeserver(
- http_client=self.http_client,
- proxied_http_client=self.http_client,
- config=config,
- )
+ return config
- self.handler = OidcHandler(hs)
+ def make_homeserver(self, reactor, clock):
+
+ self.http_client = Mock(spec=["get_json"])
+ self.http_client.get_json.side_effect = get_json
+ self.http_client.user_agent = "Synapse Test"
+
+ hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+
+ self.handler = hs.get_oidc_handler()
+ sso_handler = hs.get_sso_handler()
+ # Mock the render error method.
+ self.render_error = Mock(return_value=None)
+ sso_handler.render_error = self.render_error
+
+ # Reduce the number of attempts when generating MXIDs.
+ sso_handler._MAP_USERNAME_RETRIES = 3
return hs
@@ -161,12 +169,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
return patch.dict(self.handler._provider_metadata, values)
def assertRenderedError(self, error, error_description=None):
- args = self.handler._render_error.call_args[0]
+ args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
if error_description is not None:
self.assertEqual(args[2], error_description)
# Reset the render_error mock
- self.handler._render_error.reset_mock()
+ self.render_error.reset_mock()
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
@@ -286,9 +294,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
h._validate_metadata,
)
- # Tests for configs that the userinfo endpoint
+ # Tests for configs that require the userinfo endpoint
self.assertFalse(h._uses_userinfo)
- h._scopes = [] # do not request the openid scope
+ self.assertEqual(h._user_profile_method, "auto")
+ h._user_profile_method = "userinfo_endpoint"
+ self.assertTrue(h._uses_userinfo)
+
+ # Revert the profile method and do not request the "openid" scope.
+ h._user_profile_method = "auto"
+ h._scopes = []
self.assertTrue(h._uses_userinfo)
self.assertRaisesRegex(ValueError, "userinfo_endpoint", h._validate_metadata)
@@ -350,7 +364,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
- self.handler._render_error = Mock()
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
self.get_success(self.handler.handle_oidc_callback(request))
@@ -381,14 +394,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
"preferred_username": "bar",
}
user_id = "@foo:domain.org"
- self.handler._render_error = Mock(return_value=None)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
request = Mock(
- spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ spec=[
+ "args",
+ "getCookie",
+ "addCookie",
+ "requestHeaders",
+ "getClientIP",
+ "get_user_agent",
+ ]
)
code = "code"
@@ -408,9 +427,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args[b"code"] = [code.encode("utf-8")]
request.args[b"state"] = [state.encode("utf-8")]
- request.requestHeaders = Mock(spec=["getRawHeaders"])
- request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
request.getClientIP.return_value = ip_address
+ request.get_user_agent.return_value = user_agent
self.get_success(self.handler.handle_oidc_callback(request))
@@ -423,7 +441,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_not_called()
- self.handler._render_error.assert_not_called()
+ self.render_error.assert_not_called()
# Handle mapping errors
self.handler._map_userinfo_to_user = simple_async_mock(
@@ -457,7 +475,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, user_agent, ip_address
)
self.handler._fetch_userinfo.assert_called_once_with(token)
- self.handler._render_error.assert_not_called()
+ self.render_error.assert_not_called()
# Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
@@ -473,7 +491,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
- self.handler._render_error = Mock(return_value=None)
request = Mock(spec=["args", "getCookie", "addCookie"])
# Missing cookie
@@ -615,7 +632,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
self.handler._auth_handler.complete_sso_login = simple_async_mock()
request = Mock(
- spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ spec=[
+ "args",
+ "getCookie",
+ "addCookie",
+ "requestHeaders",
+ "getClientIP",
+ "get_user_agent",
+ ]
)
state = "state"
@@ -631,9 +655,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args[b"code"] = [b"code"]
request.args[b"state"] = [state.encode("utf-8")]
- request.requestHeaders = Mock(spec=["getRawHeaders"])
- request.requestHeaders.getRawHeaders.return_value = [b"Browser"]
request.getClientIP.return_value = "10.0.0.1"
+ request.get_user_agent.return_value = "Browser"
self.get_success(self.handler.handle_oidc_callback(request))
@@ -681,19 +704,131 @@ class OidcHandlerTestCase(HomeserverTestCase):
),
MappingException,
)
- self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken")
+ self.assertEqual(
+ str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
+ )
@override_config({"oidc_config": {"allow_existing_users": True}})
def test_map_userinfo_to_existing_user(self):
"""Existing users can log in with OpenID Connect when allow_existing_users is True."""
store = self.hs.get_datastore()
- user4 = UserID.from_string("@test_user_4:test")
+ user = UserID.from_string("@test_user:test")
+ self.get_success(
+ store.register_user(user_id=user.to_string(), password_hash=None)
+ )
+
+ # Map a user via SSO.
+ userinfo = {
+ "sub": "test",
+ "username": "test_user",
+ }
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ # Subsequent calls should map to the same mxid.
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ # Note that a second SSO user can be mapped to the same Matrix ID. (This
+ # requires a unique sub, but something that maps to the same matrix ID,
+ # in this case we'll just use the same username. A more realistic example
+ # would be subs which are email addresses, and mapping from the localpart
+ # of the email, e.g. bob@foo.com and bob@bar.com -> @bob:test.)
+ userinfo = {
+ "sub": "test1",
+ "username": "test_user",
+ }
+ token = {}
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ # Register some non-exact matching cases.
+ user2 = UserID.from_string("@TEST_user_2:test")
+ self.get_success(
+ store.register_user(user_id=user2.to_string(), password_hash=None)
+ )
+ user2_caps = UserID.from_string("@test_USER_2:test")
+ self.get_success(
+ store.register_user(user_id=user2_caps.to_string(), password_hash=None)
+ )
+
+ # Attempting to login without matching a name exactly is an error.
+ userinfo = {
+ "sub": "test2",
+ "username": "TEST_USER_2",
+ }
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertTrue(
+ str(e.value).startswith(
+ "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
+ )
+ )
+
+ # Logging in when matching a name exactly should work.
+ user2 = UserID.from_string("@TEST_USER_2:test")
+ self.get_success(
+ store.register_user(user_id=user2.to_string(), password_hash=None)
+ )
+
+ mxid = self.get_success(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@TEST_USER_2:test")
+
+ def test_map_userinfo_to_invalid_localpart(self):
+ """If the mapping provider generates an invalid localpart it should be rejected."""
+ userinfo = {
+ "sub": "test2",
+ "username": "föö",
+ }
+ token = {}
+
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(str(e.value), "localpart is invalid: föö")
+
+ @override_config(
+ {
+ "oidc_config": {
+ "user_mapping_provider": {
+ "module": __name__ + ".TestMappingProviderFailures"
+ }
+ }
+ }
+ )
+ def test_map_userinfo_to_user_retries(self):
+ """The mapping provider can retry generating an MXID if the MXID is already in use."""
+ store = self.hs.get_datastore()
self.get_success(
- store.register_user(user_id=user4.to_string(), password_hash=None)
+ store.register_user(user_id="@test_user:test", password_hash=None)
)
userinfo = {
- "sub": "test4",
- "username": "test_user_4",
+ "sub": "test",
+ "username": "test_user",
}
token = {}
mxid = self.get_success(
@@ -701,4 +836,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
userinfo, token, "user-agent", "10.10.10.10"
)
)
- self.assertEqual(mxid, "@test_user_4:test")
+ # test_user is already taken, so test_user1 gets registered instead.
+ self.assertEqual(mxid, "@test_user1:test")
+
+ # Register all of the potential mxids for a particular OIDC username.
+ self.get_success(
+ store.register_user(user_id="@tester:test", password_hash=None)
+ )
+ for i in range(1, 3):
+ self.get_success(
+ store.register_user(user_id="@tester%d:test" % i, password_hash=None)
+ )
+
+ # Now attempt to map to a username, this will fail since all potential usernames are taken.
+ userinfo = {
+ "sub": "tester",
+ "username": "tester",
+ }
+ e = self.get_failure(
+ self.handler._map_userinfo_to_user(
+ userinfo, token, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(
+ str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ )
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
new file mode 100644
index 0000000000..ceaf0902d2
--- /dev/null
+++ b/tests/handlers/test_password_providers.py
@@ -0,0 +1,580 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for the password_auth_provider interface"""
+
+from typing import Any, Type, Union
+
+from mock import Mock
+
+from twisted.internet import defer
+
+import synapse
+from synapse.rest.client.v1 import login
+from synapse.rest.client.v2_alpha import devices
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.server import FakeChannel
+from tests.unittest import override_config
+
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
+# a mock instance which the dummy auth providers delegate to, so we can see what's going
+# on
+mock_password_provider = Mock()
+
+
+class PasswordOnlyAuthProvider:
+ """A password_provider which only implements `check_password`."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def check_password(self, *args):
+ return mock_password_provider.check_password(*args)
+
+
+class CustomAuthProvider:
+ """A password_provider which implements a custom login type."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def get_supported_login_types(self):
+ return {"test.login_type": ["test_field"]}
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+
+class PasswordCustomAuthProvider:
+ """A password_provider which implements password login via `check_auth`, as well
+ as a custom type."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, account_handler):
+ pass
+
+ def get_supported_login_types(self):
+ return {"m.login.password": ["password"], "test.login_type": ["test_field"]}
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+
+def providers_config(*providers: Type[Any]) -> dict:
+ """Returns a config dict that will enable the given password auth providers"""
+ return {
+ "password_providers": [
+ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
+ for provider in providers
+ ]
+ }
+
+
+class PasswordAuthProviderTests(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def setUp(self):
+ # we use a global mock device, so make sure we are starting with a clean slate
+ mock_password_provider.reset_mock()
+ super().setUp()
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_password_only_auth_provider_login(self):
+ # login flows should only have m.login.password
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@u:test", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+ mock_password_provider.reset_mock()
+
+ # login with mxid should work too
+ channel = self._send_password_login("@u:bz", "p")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@u:bz", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with("@u:bz", "p")
+ mock_password_provider.reset_mock()
+
+ # try a weird username / pass. Honestly it's unclear what we *expect* to happen
+ # in these cases, but at least we can guard against the API changing
+ # unexpectedly
+ channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"])
+ mock_password_provider.check_password.assert_called_once_with(
+ "@ USER🙂NAME :test", " pASS😢word "
+ )
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_password_only_auth_provider_ui_auth(self):
+ """UI Auth should delegate correctly to the password provider"""
+
+ # create the user, otherwise access doesn't work
+ module_api = self.hs.get_module_api()
+ self.get_success(module_api.register_user("u"))
+
+ # log in twice, to get two devices
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ tok1 = self.login("u", "p")
+ self.login("u", "p", device_id="dev2")
+ mock_password_provider.reset_mock()
+
+ # have the auth provider deny the request to start with
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+
+ # make the initial request which returns a 401
+ session = self._start_delete_device_session(tok1, "dev2")
+ mock_password_provider.check_password.assert_not_called()
+
+ # Make another request providing the UI auth flow.
+ channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+ mock_password_provider.reset_mock()
+
+ # Finally, check the request goes through when we allow it
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+ channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_local_user_fallback_login(self):
+ """rejected login should fall back to local db"""
+ self.register_user("localuser", "localpass")
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 403, channel.result)
+
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@localuser:test", channel.json_body["user_id"])
+
+ @override_config(providers_config(PasswordOnlyAuthProvider))
+ def test_local_user_fallback_ui_auth(self):
+ """rejected login should fall back to local db"""
+ self.register_user("localuser", "localpass")
+
+ # have the auth provider deny the request
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+
+ # log in twice, to get two devices
+ tok1 = self.login("localuser", "localpass")
+ self.login("localuser", "localpass", device_id="dev2")
+ mock_password_provider.check_password.reset_mock()
+
+ # first delete should give a 401
+ session = self._start_delete_device_session(tok1, "dev2")
+ mock_password_provider.check_password.assert_not_called()
+
+ # Wrong password
+ channel = self._authed_delete_device(tok1, "dev2", session, "localuser", "xxx")
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "xxx"
+ )
+ mock_password_provider.reset_mock()
+
+ # Right password
+ channel = self._authed_delete_device(
+ tok1, "dev2", session, "localuser", "localpass"
+ )
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_no_local_user_fallback_login(self):
+ """localdb_enabled can block login with the local password
+ """
+ self.register_user("localuser", "localpass")
+
+ # check_password must return an awaitable
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_no_local_user_fallback_ui_auth(self):
+ """localdb_enabled can block ui auth with the local password
+ """
+ self.register_user("localuser", "localpass")
+
+ # allow login via the auth provider
+ mock_password_provider.check_password.return_value = defer.succeed(True)
+
+ # log in twice, to get two devices
+ tok1 = self.login("localuser", "p")
+ self.login("localuser", "p", device_id="dev2")
+ mock_password_provider.check_password.reset_mock()
+
+ # first delete should give a 401
+ channel = self._delete_device(tok1, "dev2")
+ self.assertEqual(channel.code, 401)
+ # m.login.password UIA is permitted because the auth provider allows it,
+ # even though the localdb does not.
+ self.assertEqual(channel.json_body["flows"], [{"stages": ["m.login.password"]}])
+ session = channel.json_body["session"]
+ mock_password_provider.check_password.assert_not_called()
+
+ # now try deleting with the local password
+ mock_password_provider.check_password.return_value = defer.succeed(False)
+ channel = self._authed_delete_device(
+ tok1, "dev2", session, "localuser", "localpass"
+ )
+ self.assertEqual(channel.code, 401) # XXX why not a 403?
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_password.assert_called_once_with(
+ "@localuser:test", "localpass"
+ )
+
+ @override_config(
+ {
+ **providers_config(PasswordOnlyAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_auth_disabled(self):
+ """password auth doesn't work if it's disabled across the board"""
+ # login flows should be empty
+ flows = self._get_login_flows()
+ self.assertEqual(flows, ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("u", "p")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_password.assert_not_called()
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_login(self):
+ # login flows should have the custom flow and m.login.password, since we
+ # haven't disabled local password lookup.
+ # (password must come first, because reasons)
+ flows = self._get_login_flows()
+ self.assertEqual(
+ flows,
+ [{"type": "m.login.password"}, {"type": "test.login_type"}]
+ + ADDITIONAL_LOGIN_FLOWS,
+ )
+
+ # login with missing param should be rejected
+ channel = self._send_login("test.login_type", "u")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ channel = self._send_login("test.login_type", "u", test_field="y")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@user:bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ "u", "test.login_type", {"test_field": "y"}
+ )
+ mock_password_provider.reset_mock()
+
+ # try a weird username. Again, it's unclear what we *expect* to happen
+ # in these cases, but at least we can guard against the API changing
+ # unexpectedly
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@ MALFORMED! :bz"
+ )
+ channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ " USER🙂NAME ", "test.login_type", {"test_field": " abc "}
+ )
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_ui_auth(self):
+ # register the user and log in twice, to get two devices
+ self.register_user("localuser", "localpass")
+ tok1 = self.login("localuser", "localpass")
+ self.login("localuser", "localpass", device_id="dev2")
+
+ # make the initial request which returns a 401
+ channel = self._delete_device(tok1, "dev2")
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+ self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
+ session = channel.json_body["session"]
+
+ # missing param
+ body = {
+ "auth": {
+ "type": "test.login_type",
+ "identifier": {"type": "m.id.user", "user": "localuser"},
+ "session": session,
+ },
+ }
+
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 400)
+ # there's a perfectly good M_MISSING_PARAM errcode, but heaven forfend we should
+ # use it...
+ self.assertIn("Missing parameters", channel.json_body["error"])
+ mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.reset_mock()
+
+ # right params, but authing as the wrong user
+ mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ body["auth"]["test_field"] = "foo"
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 403)
+ self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "foo"}
+ )
+ mock_password_provider.reset_mock()
+
+ # and finally, succeed
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@localuser:test"
+ )
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "foo"}
+ )
+
+ @override_config(providers_config(CustomAuthProvider))
+ def test_custom_auth_provider_callback(self):
+ callback = Mock(return_value=defer.succeed(None))
+
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ ("@user:bz", callback)
+ )
+ channel = self._send_login("test.login_type", "u", test_field="y")
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual("@user:bz", channel.json_body["user_id"])
+ mock_password_provider.check_auth.assert_called_once_with(
+ "u", "test.login_type", {"test_field": "y"}
+ )
+
+ # check the args to the callback
+ callback.assert_called_once()
+ call_args, call_kwargs = callback.call_args
+ # should be one positional arg
+ self.assertEqual(len(call_args), 1)
+ self.assertEqual(call_args[0]["user_id"], "@user:bz")
+ for p in ["user_id", "access_token", "device_id", "home_server"]:
+ self.assertIn(p, call_args[0])
+
+ @override_config(
+ {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
+ )
+ def test_custom_auth_password_disabled(self):
+ """Test login with a custom auth provider where password login is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
+ **providers_config(PasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_login(self):
+ """log in with a custom auth provider which implements password, but password
+ login is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
+ **providers_config(PasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_ui_auth(self):
+ """UI Auth with a custom auth provider which implements password, but password
+ login is disabled"""
+ # register the user and log in twice via the test login type to get two devices,
+ self.register_user("localuser", "localpass")
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ "@localuser:test"
+ )
+ channel = self._send_login("test.login_type", "localuser", test_field="")
+ self.assertEqual(channel.code, 200, channel.result)
+ tok1 = channel.json_body["access_token"]
+
+ channel = self._send_login(
+ "test.login_type", "localuser", test_field="", device_id="dev2"
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # make the initial request which returns a 401
+ channel = self._delete_device(tok1, "dev2")
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected. In particular, "password" should *not*
+ # be present.
+ self.assertIn({"stages": ["test.login_type"]}, channel.json_body["flows"])
+ session = channel.json_body["session"]
+
+ mock_password_provider.reset_mock()
+
+ # check that auth with password is rejected
+ body = {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": "localuser"},
+ "password": "localpass",
+ "session": session,
+ },
+ }
+
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 400)
+ self.assertEqual(
+ "Password login has been disabled.", channel.json_body["error"]
+ )
+ mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.reset_mock()
+
+ # successful auth
+ body["auth"]["type"] = "test.login_type"
+ body["auth"]["test_field"] = "x"
+ channel = self._delete_device(tok1, "dev2", body)
+ self.assertEqual(channel.code, 200)
+ mock_password_provider.check_auth.assert_called_once_with(
+ "localuser", "test.login_type", {"test_field": "x"}
+ )
+
+ @override_config(
+ {
+ **providers_config(CustomAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_no_local_user_fallback(self):
+ """Test login with a custom auth provider where the local db is disabled"""
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # password login shouldn't work and should be rejected with a 400
+ # ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+
+ def _get_login_flows(self) -> JsonDict:
+ _, channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+ return channel.json_body["flows"]
+
+ def _send_password_login(self, user: str, password: str) -> FakeChannel:
+ return self._send_login(type="m.login.password", user=user, password=password)
+
+ def _send_login(self, type, user, **params) -> FakeChannel:
+ params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
+ _, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
+ return channel
+
+ def _start_delete_device_session(self, access_token, device_id) -> str:
+ """Make an initial delete device request, and return the UI Auth session ID"""
+ channel = self._delete_device(access_token, device_id)
+ self.assertEqual(channel.code, 401)
+ # Ensure that flows are what is expected.
+ self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
+ return channel.json_body["session"]
+
+ def _authed_delete_device(
+ self,
+ access_token: str,
+ device_id: str,
+ session: str,
+ user_id: str,
+ password: str,
+ ) -> FakeChannel:
+ """Make a delete device request, authenticating with the given uid/password"""
+ return self._delete_device(
+ access_token,
+ device_id,
+ {
+ "auth": {
+ "type": "m.login.password",
+ "identifier": {"type": "m.id.user", "user": user_id},
+ "password": password,
+ "session": session,
+ },
+ },
+ )
+
+ def _delete_device(
+ self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
+ ) -> FakeChannel:
+ """Delete an individual device."""
+ _, channel = self.make_request(
+ "DELETE", "devices/" + device, body, access_token=access_token
+ )
+ return channel
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 306dcfe944..8ed67640f8 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -470,7 +470,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.federation_sender = hs.get_federation_sender()
self.event_builder_factory = hs.get_event_builder_factory()
- self.federation_handler = hs.get_handlers().federation_handler
+ self.federation_handler = hs.get_federation_handler()
self.presence_handler = hs.get_presence_handler()
# self.event_builder_for_2 = EventBuilderFactory(hs)
@@ -615,7 +615,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
self.store.get_latest_event_ids_in_room(room_id)
)
- event = self.get_success(builder.build(prev_event_ids))
+ event = self.get_success(builder.build(prev_event_ids, None))
self.get_success(self.federation_handler.on_receive_pdu(hostname, event))
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 8e95e53d9e..a69fa28b41 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -20,7 +20,6 @@ from twisted.internet import defer
import synapse.types
from synapse.api.errors import AuthError, SynapseError
-from synapse.handlers.profile import MasterProfileHandler
from synapse.types import UserID
from tests import unittest
@@ -28,11 +27,6 @@ from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
-class ProfileHandlers:
- def __init__(self, hs):
- self.profile_handler = MasterProfileHandler(hs)
-
-
class ProfileTestCase(unittest.TestCase):
""" Tests profile management. """
@@ -51,7 +45,6 @@ class ProfileTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
self.addCleanup,
http_client=None,
- handlers=None,
resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_server=Mock(),
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index cb7c0ed51a..bdf3d0a8a2 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -18,7 +18,6 @@ from mock import Mock
from synapse.api.auth import Auth
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, ResourceLimitError, SynapseError
-from synapse.handlers.register import RegistrationHandler
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import RoomAlias, UserID, create_requester
@@ -29,11 +28,6 @@ from tests.utils import mock_getRawHeaders
from .. import unittest
-class RegistrationHandlers:
- def __init__(self, hs):
- self.registration_handler = RegistrationHandler(hs)
-
-
class RegistrationTestCase(unittest.HomeserverTestCase):
""" Tests the RegistrationHandler. """
@@ -154,7 +148,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
room_alias_str = "#room:test"
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -193,7 +187,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
self.get_failure(directory_handler.get_association(room_alias), SynapseError)
@@ -205,7 +199,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.store.is_real_user = Mock(return_value=make_awaitable(True))
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -237,7 +231,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -266,7 +260,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -304,7 +298,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -347,7 +341,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -384,7 +378,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
user_id = self.get_success(self.handler.register_user(localpart="jeff"))
# Ensure the room was created.
- directory_handler = self.hs.get_handlers().directory_handler
+ directory_handler = self.hs.get_directory_handler()
room_alias = RoomAlias.from_string(room_alias_str)
room_id = self.get_success(directory_handler.get_association(room_alias))
@@ -413,7 +407,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
- event_creation_handler.send_nonmember_event(requester, event, context)
+ event_creation_handler.handle_new_client_event(requester, event, context)
)
# Register a second user, which won't be be in the room (or even have an invite)
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
new file mode 100644
index 0000000000..45dc17aba5
--- /dev/null
+++ b/tests/handlers/test_saml.py
@@ -0,0 +1,196 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import attr
+
+from synapse.api.errors import RedirectException
+from synapse.handlers.sso import MappingException
+
+from tests.unittest import HomeserverTestCase, override_config
+
+# These are a few constants that are used as config parameters in the tests.
+BASE_URL = "https://synapse/"
+
+
+@attr.s
+class FakeAuthnResponse:
+ ava = attr.ib(type=dict)
+
+
+class TestMappingProvider:
+ def __init__(self, config, module):
+ pass
+
+ @staticmethod
+ def parse_config(config):
+ return
+
+ @staticmethod
+ def get_saml_attributes(config):
+ return {"uid"}, {"displayName"}
+
+ def get_remote_user_id(self, saml_response, client_redirect_url):
+ return saml_response.ava["uid"]
+
+ def saml_response_to_user_attributes(
+ self, saml_response, failures, client_redirect_url
+ ):
+ localpart = saml_response.ava["username"] + (str(failures) if failures else "")
+ return {"mxid_localpart": localpart, "displayname": None}
+
+
+class TestRedirectMappingProvider(TestMappingProvider):
+ def saml_response_to_user_attributes(
+ self, saml_response, failures, client_redirect_url
+ ):
+ raise RedirectException(b"https://custom-saml-redirect/")
+
+
+class SamlHandlerTestCase(HomeserverTestCase):
+ def default_config(self):
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+ saml_config = {
+ "sp_config": {"metadata": {}},
+ # Disable grandfathering.
+ "grandfathered_mxid_source_attribute": None,
+ "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
+ }
+
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ saml_config.update(config.get("saml2_config", {}))
+ config["saml2_config"] = saml_config
+
+ return config
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+
+ self.handler = hs.get_saml_handler()
+
+ # Reduce the number of attempts when generating MXIDs.
+ sso_handler = hs.get_sso_handler()
+ sso_handler._MAP_USERNAME_RETRIES = 3
+
+ return hs
+
+ def test_map_saml_response_to_user(self):
+ """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
+ saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
+ # The redirect_url doesn't matter with the default user mapping provider.
+ redirect_url = ""
+ mxid = self.get_success(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
+ def test_map_saml_response_to_existing_user(self):
+ """Existing users can log in with SAML account."""
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+
+ # Map a user via SSO.
+ saml_response = FakeAuthnResponse(
+ {"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
+ )
+ redirect_url = ""
+ mxid = self.get_success(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ # Subsequent calls should map to the same mxid.
+ mxid = self.get_success(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ )
+ )
+ self.assertEqual(mxid, "@test_user:test")
+
+ def test_map_saml_response_to_invalid_localpart(self):
+ """If the mapping provider generates an invalid localpart it should be rejected."""
+ saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
+ redirect_url = ""
+ e = self.get_failure(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(str(e.value), "localpart is invalid: föö")
+
+ def test_map_saml_response_to_user_retries(self):
+ """The mapping provider can retry generating an MXID if the MXID is already in use."""
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+ saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
+ redirect_url = ""
+ mxid = self.get_success(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ )
+ )
+ # test_user is already taken, so test_user1 gets registered instead.
+ self.assertEqual(mxid, "@test_user1:test")
+
+ # Register all of the potential mxids for a particular SAML username.
+ self.get_success(
+ store.register_user(user_id="@tester:test", password_hash=None)
+ )
+ for i in range(1, 3):
+ self.get_success(
+ store.register_user(user_id="@tester%d:test" % i, password_hash=None)
+ )
+
+ # Now attempt to map to a username, this will fail since all potential usernames are taken.
+ saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
+ e = self.get_failure(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ ),
+ MappingException,
+ )
+ self.assertEqual(
+ str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ )
+
+ @override_config(
+ {
+ "saml2_config": {
+ "user_mapping_provider": {
+ "module": __name__ + ".TestRedirectMappingProvider"
+ },
+ }
+ }
+ )
+ def test_map_saml_response_redirect(self):
+ saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
+ redirect_url = ""
+ e = self.get_failure(
+ self.handler._map_saml_response_to_user(
+ saml_response, redirect_url, "user-agent", "10.10.10.10"
+ ),
+ RedirectException,
+ )
+ self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index e178d7765b..e62586142e 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -16,7 +16,7 @@
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
import tests.unittest
import tests.utils
@@ -38,6 +38,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
user_id1 = "@user1:test"
user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1)
+ requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time
self.auth_blocking._limit_usage_by_mau = True
@@ -45,21 +46,26 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Check that the happy case does not throw errors
self.get_success(self.store.upsert_monthly_active_user(user_id1))
- self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
+ self.get_success(
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config)
+ )
# Test that global lock works
self.auth_blocking._hs_disabled = True
e = self.get_failure(
- self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config),
+ ResourceLimitError,
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.auth_blocking._hs_disabled = False
sync_config = self._generate_sync_config(user_id2)
+ requester = create_requester(user_id2)
e = self.get_failure(
- self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
+ self.sync_handler.wait_for_sync_for_user(requester, sync_config),
+ ResourceLimitError,
)
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 3fec09ea8a..abbdf2d524 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -65,26 +65,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
mock_federation_client = Mock(spec=["put_json"])
mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
- datastores = Mock()
- datastores.main = Mock(
- spec=[
- # Bits that Federation needs
- "prep_send_transaction",
- "delivered_txn",
- "get_received_txn_response",
- "set_received_txn_response",
- "get_destination_last_successful_stream_ordering",
- "get_destination_retry_timings",
- "get_devices_by_remote",
- "maybe_store_room_on_invite",
- # Bits that user_directory needs
- "get_user_directory_stream_pos",
- "get_current_state_deltas",
- "get_device_updates_by_remote",
- "get_room_max_stream_ordering",
- ]
- )
-
# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
@@ -95,8 +75,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
replication_streams={},
)
- hs.datastores = datastores
-
return hs
def prepare(self, reactor, clock, hs):
@@ -114,16 +92,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
"retry_interval": 0,
"failure_ts": None,
}
- self.datastore.get_destination_retry_timings.return_value = defer.succeed(
- retry_timings_res
+ self.datastore.get_destination_retry_timings = Mock(
+ return_value=defer.succeed(retry_timings_res)
)
- self.datastore.get_device_updates_by_remote.return_value = make_awaitable(
- (0, [])
+ self.datastore.get_device_updates_by_remote = Mock(
+ return_value=make_awaitable((0, []))
)
- self.datastore.get_destination_last_successful_stream_ordering.return_value = make_awaitable(
- None
+ self.datastore.get_destination_last_successful_stream_ordering = Mock(
+ return_value=make_awaitable(None)
)
def get_received_txn_response(*args):
@@ -145,17 +123,19 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
- def get_users_in_room(room_id):
- return defer.succeed({str(u) for u in self.room_members})
+ async def get_users_in_room(room_id):
+ return {str(u) for u in self.room_members}
self.datastore.get_users_in_room = get_users_in_room
- self.datastore.get_user_directory_stream_pos.side_effect = (
- # we deliberately return a non-None stream pos to avoid doing an initial_spam
- lambda: make_awaitable(1)
+ self.datastore.get_user_directory_stream_pos = Mock(
+ side_effect=(
+ # we deliberately return a non-None stream pos to avoid doing an initial_spam
+ lambda: make_awaitable(1)
+ )
)
- self.datastore.get_current_state_deltas.return_value = (0, None)
+ self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: make_awaitable(
@@ -248,7 +228,6 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
),
federation_auth_origin=b"farm",
)
- self.render(request)
self.assertEqual(channel.code, 200)
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 87be94111f..98e5af2072 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -537,7 +537,6 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) > 0)
@@ -546,6 +545,5 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
request, channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
- self.render(request)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) == 0)
|