diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 9014e60577..46d022092e 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Dict, Iterable, List, Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from parameterized import parameterized
@@ -36,7 +36,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.test_utils import event_injection, make_awaitable, simple_async_mock
+from tests.test_utils import event_injection
from tests.unittest import override_config
from tests.utils import MockClock
@@ -46,15 +46,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def setUp(self) -> None:
self.mock_store = Mock()
- self.mock_as_api = Mock()
+ self.mock_as_api = AsyncMock()
self.mock_scheduler = Mock()
hs = Mock()
hs.get_datastores.return_value = Mock(main=self.mock_store)
- self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None)
- self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
- self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
- None
- )
+ self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None)
+ self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None)
+ self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None)
hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
@@ -69,21 +67,25 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice(is_interested_in_event=False),
]
- self.mock_as_api.query_user.return_value = make_awaitable(True)
+ self.mock_as_api.query_user.return_value = True
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = make_awaitable([])
+ self.mock_store.get_user_by_id = AsyncMock(return_value=[])
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
- self.mock_store.get_all_new_event_ids_stream.side_effect = [
- make_awaitable((0, {})),
- make_awaitable((1, {event.event_id: 0})),
- ]
- self.mock_store.get_events_as_list.side_effect = [
- make_awaitable([]),
- make_awaitable([event]),
- ]
+ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+ side_effect=[
+ (0, {}),
+ (1, {event.event_id: 0}),
+ ]
+ )
+ self.mock_store.get_events_as_list = AsyncMock(
+ side_effect=[
+ [],
+ [event],
+ ]
+ )
self.handler.notify_interested_services(RoomStreamToken(None, 1))
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@@ -95,14 +97,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = make_awaitable(None)
+ self.mock_store.get_user_by_id = AsyncMock(return_value=None)
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_event_ids_stream.side_effect = [
- make_awaitable((0, {event.event_id: 0})),
- ]
- self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
+ self.mock_as_api.query_user.return_value = True
+ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+ side_effect=[
+ (0, {event.event_id: 0}),
+ ]
+ )
+ self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
self.handler.notify_interested_services(RoomStreamToken(None, 0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@@ -112,13 +116,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
+ self.mock_store.get_user_by_id = AsyncMock(return_value={"name": user_id})
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_event_ids_stream.side_effect = [
- make_awaitable((0, [event], {event.event_id: 0})),
- ]
+ self.mock_as_api.query_user.return_value = True
+ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+ side_effect=[
+ (0, [event], {event.event_id: 0}),
+ ]
+ )
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -141,10 +147,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice_alias(is_room_alias_in_namespace=False),
]
- self.mock_as_api.query_alias.return_value = make_awaitable(True)
+ self.mock_as_api.query_alias = AsyncMock(return_value=True)
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
- Mock(room_id=room_id, servers=servers)
+ self.mock_store.get_association_from_room_alias = AsyncMock(
+ return_value=Mock(room_id=room_id, servers=servers)
)
result = self.successResultOf(
@@ -177,7 +183,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_protocol_no_response(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
+ self.mock_as_api.get_3pe_protocol.return_value = None
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols())
)
@@ -189,9 +195,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_select_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
- {"x-protocol-data": 42, "instances": []}
- )
+ self.mock_as_api.get_3pe_protocol.return_value = {
+ "x-protocol-data": 42,
+ "instances": [],
+ }
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
)
@@ -205,9 +212,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
- {"x-protocol-data": 42, "instances": []}
- )
+ self.mock_as_api.get_3pe_protocol.return_value = {
+ "x-protocol-data": 42,
+ "instances": [],
+ }
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols())
)
@@ -222,9 +230,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service_one = self._mkservice(False, ["my-protocol"])
service_two = self._mkservice(False, ["other-protocol"])
self.mock_store.get_app_services.return_value = [service_one, service_two]
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
- {"x-protocol-data": 42, "instances": []}
- )
+ self.mock_as_api.get_3pe_protocol.return_value = {
+ "x-protocol-data": 42,
+ "instances": [],
+ }
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols())
)
@@ -287,13 +296,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service = self._mkservice(is_interested_in_event=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
- 579
- )
+ self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=579)
event = Mock(event_id="event_1")
- self.event_source.sources.receipt.get_new_events_as.return_value = (
- make_awaitable(([event], None))
+ self.event_source.sources.receipt.get_new_events_as = AsyncMock(
+ return_value=([event], None)
)
self.handler.notify_interested_services_ephemeral(
@@ -317,13 +324,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [interested_service]
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
- 580
- )
+ self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=580)
event = Mock(event_id="event_1")
- self.event_source.sources.receipt.get_new_events_as.return_value = (
- make_awaitable(([event], None))
+ self.event_source.sources.receipt.get_new_events_as = AsyncMock(
+ return_value=([event], None)
)
self.handler.notify_interested_services_ephemeral(
@@ -350,9 +355,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
A mock representing the ApplicationService.
"""
service = Mock()
- service.is_interested_in_event.return_value = make_awaitable(
- is_interested_in_event
- )
+ service.is_interested_in_event = AsyncMock(return_value=is_interested_in_event)
service.token = "mock_service_token"
service.url = "mock_service_url"
service.protocols = protocols
@@ -396,12 +399,12 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.hs = hs
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events
- self.send_mock = simple_async_mock()
- hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment]
+ self.send_mock = AsyncMock()
+ hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[method-assign]
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
- self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
+ self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[method-assign]
return_value=self._services
)
@@ -894,12 +897,12 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
# Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
# will be sent over the wire
- self.put_json = simple_async_mock()
- hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment]
+ self.put_json = AsyncMock()
+ hs.get_application_service_api().put_json = self.put_json # type: ignore[method-assign]
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
- self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
+ self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[method-assign]
return_value=self._services
)
@@ -1000,8 +1003,8 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track what's going out
- self.send_mock = simple_async_mock()
- hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method.
+ self.send_mock = AsyncMock()
+ hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[method-assign] # We assign to a method.
# Define an application service for the tests
self._service_token = "VERYSECRET"
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 036dbbc45b..413ff8795b 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
import pymacaroons
@@ -25,7 +25,6 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import make_awaitable
class AuthTestCase(unittest.HomeserverTestCase):
@@ -166,8 +165,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_mau_limits_exceeded_large(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
- self.hs.get_datastores().main.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.large_number_of_users)
+ self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+ return_value=self.large_number_of_users
)
self.get_failure(
@@ -177,8 +176,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
- self.hs.get_datastores().main.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.large_number_of_users)
+ self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+ return_value=self.large_number_of_users
)
token = self.get_success(
self.auth_handler.create_login_token_for_user_id(self.user1)
@@ -191,8 +190,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.auth_blocking._limit_usage_by_mau = True
# Set the server to be at the edge of too many users.
- self.hs.get_datastores().main.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.auth_blocking._max_mau_value)
+ self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+ return_value=self.auth_blocking._max_mau_value
)
# If not in monthly active cohort
@@ -208,8 +207,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
self.assertIsNone(self.token_login(token))
# If in monthly active cohort
- self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
- return_value=make_awaitable(self.clock.time_msec())
+ self.hs.get_datastores().main.user_last_seen_monthly_active = AsyncMock(
+ return_value=self.clock.time_msec()
)
self.get_success(
self.auth_handler.create_access_token_for_user_id(
@@ -224,8 +223,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
def test_mau_limits_not_exceeded(self) -> None:
self.auth_blocking._limit_usage_by_mau = True
- self.hs.get_datastores().main.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.small_number_of_users)
+ self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+ return_value=self.small_number_of_users
)
# Ensure does not raise exception
self.get_success(
@@ -234,8 +233,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
- self.hs.get_datastores().main.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.small_number_of_users)
+ self.hs.get_datastores().main.get_monthly_active_count = AsyncMock(
+ return_value=self.small_number_of_users
)
token = self.get_success(
self.auth_handler.create_login_token_for_user_id(self.user1)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 63aad0d10c..8582b1cd1e 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -20,7 +20,6 @@ from synapse.handlers.cas import CasResponse
from synapse.server import HomeServer
from synapse.util import Clock
-from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
# These are a few constants that are used as config parameters in the tests.
@@ -61,7 +60,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
+ auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
cas_response = CasResponse("test_user", {})
request = _mock_request()
@@ -89,7 +88,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
+ auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
# Map a user via SSO.
cas_response = CasResponse("test_user", {})
@@ -129,7 +128,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
+ auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
cas_response = CasResponse("föö", {})
request = _mock_request()
@@ -160,7 +159,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
+ auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
# The response doesn't have the proper userGroup or department.
cas_response = CasResponse("test_user", {})
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index e1e58fa6e6..55a4f95ef3 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -32,7 +32,6 @@ from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import make_awaitable
from tests.unittest import override_config
user1 = "@boris:aaa"
@@ -41,7 +40,7 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.appservice_api = mock.Mock()
+ self.appservice_api = mock.AsyncMock()
hs = self.setup_test_homeserver(
"server",
application_service_api=self.appservice_api,
@@ -123,50 +122,50 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(3, len(res))
device_map = {d["device_id"]: d for d in res}
- self.assertDictContainsSubset(
+ self.assertLessEqual(
{
"user_id": user1,
"device_id": "xyz",
"display_name": "display 0",
"last_seen_ip": None,
"last_seen_ts": None,
- },
- device_map["xyz"],
+ }.items(),
+ device_map["xyz"].items(),
)
- self.assertDictContainsSubset(
+ self.assertLessEqual(
{
"user_id": user1,
"device_id": "fco",
"display_name": "display 1",
"last_seen_ip": "ip1",
"last_seen_ts": 1000000,
- },
- device_map["fco"],
+ }.items(),
+ device_map["fco"].items(),
)
- self.assertDictContainsSubset(
+ self.assertLessEqual(
{
"user_id": user1,
"device_id": "abc",
"display_name": "display 2",
"last_seen_ip": "ip3",
"last_seen_ts": 3000000,
- },
- device_map["abc"],
+ }.items(),
+ device_map["abc"].items(),
)
def test_get_device(self) -> None:
self._record_users()
res = self.get_success(self.handler.get_device(user1, "abc"))
- self.assertDictContainsSubset(
+ self.assertLessEqual(
{
"user_id": user1,
"device_id": "abc",
"display_name": "display 2",
"last_seen_ip": "ip3",
"last_seen_ts": 3000000,
- },
- res,
+ }.items(),
+ res.items(),
)
def test_delete_device(self) -> None:
@@ -375,13 +374,11 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
# Setup a response.
- self.appservice_api.query_keys.return_value = make_awaitable(
- {
- "device_keys": {
- local_user: {device_2: device_key_2b, device_3: device_key_3}
- }
+ self.appservice_api.query_keys.return_value = {
+ "device_keys": {
+ local_user: {device_2: device_key_2b, device_3: device_key_3}
}
- )
+ }
# Request all devices.
res = self.get_success(
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 90aec484c4..367d94eca3 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Awaitable, Callable, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -27,14 +27,13 @@ from synapse.types import JsonDict, RoomAlias, create_requester
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import make_awaitable
class DirectoryTestCase(unittest.HomeserverTestCase):
"""Tests the directory service."""
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.mock_federation = Mock()
+ self.mock_federation = AsyncMock()
self.mock_registry = Mock()
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
@@ -73,9 +72,10 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result)
def test_get_remote_association(self) -> None:
- self.mock_federation.make_query.return_value = make_awaitable(
- {"room_id": "!8765qwer:test", "servers": ["test", "remote"]}
- )
+ self.mock_federation.make_query.return_value = {
+ "room_id": "!8765qwer:test",
+ "servers": ["test", "remote"],
+ }
result = self.get_success(self.handler.get_association(self.remote_room))
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 2eaffe511e..c5556f2844 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -13,7 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Iterable
+from typing import Dict, Iterable
from unittest import mock
from parameterized import parameterized
@@ -31,13 +31,12 @@ from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import make_awaitable
from tests.unittest import override_config
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.appservice_api = mock.Mock()
+ self.appservice_api = mock.AsyncMock()
return self.setup_test_homeserver(
federation_client=mock.Mock(), application_service_api=self.appservice_api
)
@@ -801,29 +800,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
- self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment]
- return_value=make_awaitable(
- {
- "device_keys": {remote_user_id: {}},
- "master_keys": {
- remote_user_id: {
- "user_id": remote_user_id,
- "usage": ["master"],
- "keys": {"ed25519:" + remote_master_key: remote_master_key},
- },
- },
- "self_signing_keys": {
- remote_user_id: {
- "user_id": remote_user_id,
- "usage": ["self_signing"],
- "keys": {
- "ed25519:"
- + remote_self_signing_key: remote_self_signing_key
- },
- }
+ self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[method-assign]
+ return_value={
+ "device_keys": {remote_user_id: {}},
+ "master_keys": {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
},
- }
- )
+ },
+ "self_signing_keys": {
+ remote_user_id: {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:"
+ + remote_self_signing_key: remote_self_signing_key
+ },
+ }
+ },
+ }
)
e2e_handler = self.hs.get_e2e_keys_handler()
@@ -874,34 +871,29 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
- self.store.get_rooms_for_user = mock.Mock(
- return_value=make_awaitable({"some_room_id"})
- )
+ self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"})
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
- self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment]
- return_value=make_awaitable(
- {
+ self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[method-assign]
+ return_value={
+ "user_id": remote_user_id,
+ "stream_id": 1,
+ "devices": [],
+ "master_key": {
"user_id": remote_user_id,
- "stream_id": 1,
- "devices": [],
- "master_key": {
- "user_id": remote_user_id,
- "usage": ["master"],
- "keys": {"ed25519:" + remote_master_key: remote_master_key},
- },
- "self_signing_key": {
- "user_id": remote_user_id,
- "usage": ["self_signing"],
- "keys": {
- "ed25519:"
- + remote_self_signing_key: remote_self_signing_key
- },
+ "usage": ["master"],
+ "keys": {"ed25519:" + remote_master_key: remote_master_key},
+ },
+ "self_signing_key": {
+ "user_id": remote_user_id,
+ "usage": ["self_signing"],
+ "keys": {
+ "ed25519:" + remote_self_signing_key: remote_self_signing_key
},
- }
- )
+ },
+ }
)
e2e_handler = self.hs.get_e2e_keys_handler()
@@ -987,20 +979,20 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
mock_get_rooms = mock.patch.object(
self.store,
"get_rooms_for_user",
- new_callable=mock.MagicMock,
- return_value=make_awaitable(["some_room_id"]),
+ new_callable=mock.AsyncMock,
+ return_value=["some_room_id"],
)
mock_get_users = mock.patch.object(
self.store,
"get_users_server_still_shares_room_with",
- new_callable=mock.MagicMock,
- return_value=make_awaitable({remote_user_id}),
+ new_callable=mock.AsyncMock,
+ return_value={remote_user_id},
)
mock_request = mock.patch.object(
self.hs.get_federation_client(),
"query_user_devices",
- new_callable=mock.MagicMock,
- return_value=make_awaitable(response_body),
+ new_callable=mock.AsyncMock,
+ return_value=response_body,
)
with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request:
@@ -1060,8 +1052,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
# Setup a response, but only for device 2.
- self.appservice_api.claim_client_keys.return_value = make_awaitable(
- ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)])
+ self.appservice_api.claim_client_keys.return_value = (
+ {local_user: {device_id_2: otk}},
+ [(local_user, device_id_1, "alg1", 1)],
)
# we shouldn't have any unused fallback keys yet
@@ -1127,9 +1120,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
# Setup a response.
- self.appservice_api.claim_client_keys.return_value = make_awaitable(
- ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, [])
- )
+ response: Dict[str, Dict[str, Dict[str, JsonDict]]] = {
+ local_user: {device_id_1: {**as_otk, **as_fallback_key}}
+ }
+ self.appservice_api.claim_client_keys.return_value = (response, [])
# Claim OTKs, which will ask the appservice and do nothing else.
claim_res = self.get_success(
@@ -1171,8 +1165,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertEqual(fallback_res, ["alg1"])
# The appservice will return only the OTK.
- self.appservice_api.claim_client_keys.return_value = make_awaitable(
- ({local_user: {device_id_1: as_otk}}, [])
+ self.appservice_api.claim_client_keys.return_value = (
+ {local_user: {device_id_1: as_otk}},
+ [],
)
# Claim OTKs, which should return the OTK from the appservice and the
@@ -1234,8 +1229,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.assertEqual(fallback_res, ["alg1"])
# Finally, return only the fallback key from the appservice.
- self.appservice_api.claim_client_keys.return_value = make_awaitable(
- ({local_user: {device_id_1: as_fallback_key}}, [])
+ self.appservice_api.claim_client_keys.return_value = (
+ {local_user: {device_id_1: as_fallback_key}},
+ [],
)
# Claim OTKs, which will return only the fallback key from the database.
@@ -1350,13 +1346,11 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
)
# Setup a response.
- self.appservice_api.query_keys.return_value = make_awaitable(
- {
- "device_keys": {
- local_user: {device_2: device_key_2b, device_3: device_key_3}
- }
+ self.appservice_api.query_keys.return_value = {
+ "device_keys": {
+ local_user: {device_2: device_key_2b, device_3: device_key_3}
}
- )
+ }
# Request all devices.
res = self.get_success(self.handler.query_local_devices({local_user: None}))
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 5f11d5df11..21d63ab1f2 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -14,7 +14,7 @@
import logging
from typing import Collection, Optional, cast
from unittest import TestCase
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
@@ -40,7 +40,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
logger = logging.getLogger(__name__)
@@ -370,15 +370,15 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# We mock out the FederationClient.backfill method, to pretend that a remote
# server has returned our fake event.
- federation_client_backfill_mock = Mock(return_value=make_awaitable([event]))
- self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment]
+ federation_client_backfill_mock = AsyncMock(return_value=[event])
+ self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[method-assign]
# We also mock the persist method with a side effect of itself. This allows us
# to track when it has been called while preserving its function.
persist_events_and_notify_mock = Mock(
side_effect=self.hs.get_federation_event_handler().persist_events_and_notify
)
- self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment]
+ self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[method-assign]
persist_events_and_notify_mock
)
@@ -631,33 +631,29 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
},
RoomVersions.V10,
)
- mock_make_membership_event = Mock(
- return_value=make_awaitable(
- (
- "example.com",
- membership_event,
- RoomVersions.V10,
- )
+ mock_make_membership_event = AsyncMock(
+ return_value=(
+ "example.com",
+ membership_event,
+ RoomVersions.V10,
)
)
- mock_send_join = Mock(
- return_value=make_awaitable(
- SendJoinResult(
- membership_event,
- "example.com",
- state=[
- EVENT_CREATE,
- EVENT_CREATOR_MEMBERSHIP,
- EVENT_INVITATION_MEMBERSHIP,
- ],
- auth_chain=[
- EVENT_CREATE,
- EVENT_CREATOR_MEMBERSHIP,
- EVENT_INVITATION_MEMBERSHIP,
- ],
- partial_state=True,
- servers_in_room={"example.com"},
- )
+ mock_send_join = AsyncMock(
+ return_value=SendJoinResult(
+ membership_event,
+ "example.com",
+ state=[
+ EVENT_CREATE,
+ EVENT_CREATOR_MEMBERSHIP,
+ EVENT_INVITATION_MEMBERSHIP,
+ ],
+ auth_chain=[
+ EVENT_CREATE,
+ EVENT_CREATOR_MEMBERSHIP,
+ EVENT_INVITATION_MEMBERSHIP,
+ ],
+ partial_state=True,
+ servers_in_room={"example.com"},
)
)
diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 23f1b33b2f..70e6a7e142 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -35,7 +35,7 @@ from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
@@ -50,6 +50,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
self.mock_federation_transport_client = mock.Mock(
spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
)
+ self.mock_federation_transport_client.get_room_state_ids = mock.AsyncMock()
+ self.mock_federation_transport_client.get_room_state = mock.AsyncMock()
+ self.mock_federation_transport_client.get_event = mock.AsyncMock()
+ self.mock_federation_transport_client.backfill = mock.AsyncMock()
return super().setup_test_homeserver(
federation_transport_client=self.mock_federation_transport_client
)
@@ -198,20 +202,14 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
# we expect an outbound request to /state_ids, so stub that out
- self.mock_federation_transport_client.get_room_state_ids.return_value = (
- make_awaitable(
- {
- "pdu_ids": [e.event_id for e in state_at_prev_event],
- "auth_chain_ids": [],
- }
- )
- )
+ self.mock_federation_transport_client.get_room_state_ids.return_value = {
+ "pdu_ids": [e.event_id for e in state_at_prev_event],
+ "auth_chain_ids": [],
+ }
# we also expect an outbound request to /state
self.mock_federation_transport_client.get_room_state.return_value = (
- make_awaitable(
- StateRequestResponse(auth_events=[], state=state_at_prev_event)
- )
+ StateRequestResponse(auth_events=[], state=state_at_prev_event)
)
# we have to bump the clock a bit, to keep the retry logic in
@@ -273,26 +271,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
room_version = self.get_success(main_store.get_room_version(room_id))
# We expect an outbound request to /state_ids, so stub that out
- self.mock_federation_transport_client.get_room_state_ids.return_value = make_awaitable(
- {
- # Mimic the other server not knowing about the state at all.
- # We want to cause Synapse to throw an error (`Unable to get
- # missing prev_event $fake_prev_event`) and fail to backfill
- # the pulled event.
- "pdu_ids": [],
- "auth_chain_ids": [],
- }
- )
+ self.mock_federation_transport_client.get_room_state_ids.return_value = {
+ # Mimic the other server not knowing about the state at all.
+ # We want to cause Synapse to throw an error (`Unable to get
+ # missing prev_event $fake_prev_event`) and fail to backfill
+ # the pulled event.
+ "pdu_ids": [],
+ "auth_chain_ids": [],
+ }
+
# We also expect an outbound request to /state
- self.mock_federation_transport_client.get_room_state.return_value = make_awaitable(
- StateRequestResponse(
- # Mimic the other server not knowing about the state at all.
- # We want to cause Synapse to throw an error (`Unable to get
- # missing prev_event $fake_prev_event`) and fail to backfill
- # the pulled event.
- auth_events=[],
- state=[],
- )
+ self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse(
+ # Mimic the other server not knowing about the state at all.
+ # We want to cause Synapse to throw an error (`Unable to get
+ # missing prev_event $fake_prev_event`) and fail to backfill
+ # the pulled event.
+ auth_events=[],
+ state=[],
)
pulled_event = make_event_from_dict(
@@ -545,25 +540,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
# We expect an outbound request to /backfill, so stub that out
- self.mock_federation_transport_client.backfill.return_value = make_awaitable(
- {
- "origin": self.OTHER_SERVER_NAME,
- "origin_server_ts": 123,
- "pdus": [
- # This is one of the important aspects of this test: we include
- # `pulled_event_without_signatures` so it fails the signature check
- # when we filter down the backfill response down to events which
- # have valid signatures in
- # `_check_sigs_and_hash_for_pulled_events_and_fetch`
- pulled_event_without_signatures.get_pdu_json(),
- # Then later when we process this valid signature event, when we
- # fetch the missing `prev_event`s, we want to make sure that we
- # backoff and don't try and fetch `pulled_event_without_signatures`
- # again since we know it just had an invalid signature.
- pulled_event.get_pdu_json(),
- ],
- }
- )
+ self.mock_federation_transport_client.backfill.return_value = {
+ "origin": self.OTHER_SERVER_NAME,
+ "origin_server_ts": 123,
+ "pdus": [
+ # This is one of the important aspects of this test: we include
+ # `pulled_event_without_signatures` so it fails the signature check
+ # when we filter down the backfill response down to events which
+ # have valid signatures in
+ # `_check_sigs_and_hash_for_pulled_events_and_fetch`
+ pulled_event_without_signatures.get_pdu_json(),
+ # Then later when we process this valid signature event, when we
+ # fetch the missing `prev_event`s, we want to make sure that we
+ # backoff and don't try and fetch `pulled_event_without_signatures`
+ # again since we know it just had an invalid signature.
+ pulled_event.get_pdu_json(),
+ ],
+ }
# Keep track of the count and make sure we don't make any of these requests
event_endpoint_requested_count = 0
@@ -731,15 +724,13 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
# We expect an outbound request to /backfill, so stub that out
- self.mock_federation_transport_client.backfill.return_value = make_awaitable(
- {
- "origin": self.OTHER_SERVER_NAME,
- "origin_server_ts": 123,
- "pdus": [
- pulled_event.get_pdu_json(),
- ],
- }
- )
+ self.mock_federation_transport_client.backfill.return_value = {
+ "origin": self.OTHER_SERVER_NAME,
+ "origin_server_ts": 123,
+ "pdus": [
+ pulled_event.get_pdu_json(),
+ ],
+ }
# The function under test: try to backfill and process the pulled event
with LoggingContext("test"):
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index 9691d66b48..1c5897c84e 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -46,18 +46,11 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self._persist_event_storage_controller = persistence
self.user_id = self.register_user("tester", "foobar")
- self.access_token = self.login("tester", "foobar")
- self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
-
- info = self.get_success(
- self.hs.get_datastores().main.get_user_by_access_token(
- self.access_token,
- )
- )
- assert info is not None
- self.token_id = info.token_id
+ device_id = "dev-1"
+ access_token = self.login("tester", "foobar", device_id=device_id)
+ self.room_id = self.helper.create_room_as(self.user_id, tok=access_token)
- self.requester = create_requester(self.user_id, access_token_id=self.token_id)
+ self.requester = create_requester(self.user_id, device_id=device_id)
def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]:
# Create a member event we can use as an auth_event
diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py
index b891e84690..9152694653 100644
--- a/tests/handlers/test_oauth_delegation.py
+++ b/tests/handlers/test_oauth_delegation.py
@@ -39,7 +39,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
-from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils import FakeResponse, get_awaitable_result
from tests.unittest import HomeserverTestCase, skip_unless
from tests.utils import mock_getRawHeaders
@@ -147,7 +147,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_inactive_token(self) -> None:
"""The handler should return a 403 where the token is inactive."""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={"active": False},
@@ -166,7 +166,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_no_scope(self) -> None:
"""The handler should return a 403 where no scope is given."""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={"active": True},
@@ -185,7 +185,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_user_no_subject(self) -> None:
"""The handler should return a 500 when no subject is present."""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
@@ -204,7 +204,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_no_user_scope(self) -> None:
"""The handler should return a 500 when no subject is present."""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@@ -227,7 +227,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_admin_not_user(self) -> None:
"""The handler should raise when the scope has admin right but not user."""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@@ -251,7 +251,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_admin(self) -> None:
"""The handler should return a requester with admin rights."""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@@ -281,7 +281,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_admin_highest_privilege(self) -> None:
"""The handler should resolve to the most permissive scope."""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@@ -313,7 +313,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_user(self) -> None:
"""The handler should return a requester with normal user rights."""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@@ -344,7 +344,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
"""The handler should return a requester with normal user rights
and an user ID matching the one specified in query param `user_id`"""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@@ -378,7 +378,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_user_with_device(self) -> None:
"""The handler should return a requester with normal user rights and a device ID."""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@@ -408,7 +408,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_multiple_devices(self) -> None:
"""The handler should raise an error if multiple devices are found in the scope."""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@@ -433,7 +433,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_guest_not_allowed(self) -> None:
"""The handler should return an insufficient scope error."""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@@ -463,7 +463,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_active_guest_allowed(self) -> None:
"""The handler should return a requester with guest user rights and a device ID."""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@@ -499,19 +499,19 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
request.requestHeaders.getRawHeaders = mock_getRawHeaders()
# The introspection endpoint is returning an error.
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse(code=500, body=b"Internal Server Error")
)
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
# The introspection endpoint request fails.
- self.http_client.request = simple_async_mock(raises=Exception())
+ self.http_client.request = AsyncMock(side_effect=Exception())
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
self.assertEqual(error.value.code, 503)
# The introspection endpoint does not return a JSON object.
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200, payload=["this is an array", "not an object"]
)
@@ -520,7 +520,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
self.assertEqual(error.value.code, 503)
# The introspection endpoint does not return valid JSON.
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse(code=200, body=b"this is not valid JSON")
)
error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
@@ -528,7 +528,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_introspection_token_cache(self) -> None:
access_token = "open_sesame"
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={"active": "true", "scope": "guest", "jti": access_token},
@@ -559,7 +559,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
# test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a
# token with a soon-to-expire `exp` field to the cache
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
@@ -640,7 +640,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
def test_cross_signing(self) -> None:
"""Try uploading device keys with OAuth delegation enabled."""
- self.http_client.request = simple_async_mock(
+ self.http_client.request = AsyncMock(
return_value=FakeResponse.json(
code=200,
payload={
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 0a8bae54fb..e797aaae00 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,7 +13,7 @@
# limitations under the License.
import os
from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
-from unittest.mock import ANY, Mock, patch
+from unittest.mock import ANY, AsyncMock, Mock, patch
from urllib.parse import parse_qs, urlparse
import pymacaroons
@@ -28,7 +28,7 @@ from synapse.util import Clock
from synapse.util.macaroons import get_value_from_macaroon
from synapse.util.stringutils import random_string
-from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils import FakeResponse, get_awaitable_result
from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
from tests.unittest import HomeserverTestCase, override_config
@@ -157,15 +157,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
- sso_handler.render_error = self.render_error # type: ignore[assignment]
+ sso_handler.render_error = self.render_error # type: ignore[method-assign]
# Reduce the number of attempts when generating MXIDs.
sso_handler._MAP_USERNAME_RETRIES = 3
auth_handler = hs.get_auth_handler()
# Mock the complete SSO login method.
- self.complete_sso_login = simple_async_mock()
- auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment]
+ self.complete_sso_login = AsyncMock()
+ auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[method-assign]
return hs
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 394006f5f3..11ec8c7f11 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -16,7 +16,7 @@
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Type, Union
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -32,7 +32,6 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
-from tests.test_utils import make_awaitable
from tests.unittest import override_config
# Login flows we expect to appear in the list after the normal ones.
@@ -187,7 +186,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
# check_password must return an awaitable
- mock_password_provider.check_password.return_value = make_awaitable(True)
+ mock_password_provider.check_password = AsyncMock(return_value=True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
@@ -209,13 +208,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"""UI Auth should delegate correctly to the password provider"""
# log in twice, to get two devices
- mock_password_provider.check_password.return_value = make_awaitable(True)
+ mock_password_provider.check_password = AsyncMock(return_value=True)
tok1 = self.login("u", "p")
self.login("u", "p", device_id="dev2")
mock_password_provider.reset_mock()
# have the auth provider deny the request to start with
- mock_password_provider.check_password.return_value = make_awaitable(False)
+ mock_password_provider.check_password = AsyncMock(return_value=False)
# make the initial request which returns a 401
session = self._start_delete_device_session(tok1, "dev2")
@@ -229,7 +228,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# Finally, check the request goes through when we allow it
- mock_password_provider.check_password.return_value = make_awaitable(True)
+ mock_password_provider.check_password = AsyncMock(return_value=True)
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@@ -243,7 +242,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# check_password must return an awaitable
- mock_password_provider.check_password.return_value = make_awaitable(False)
+ mock_password_provider.check_password = AsyncMock(return_value=False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
@@ -260,7 +259,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# have the auth provider deny the request
- mock_password_provider.check_password.return_value = make_awaitable(False)
+ mock_password_provider.check_password = AsyncMock(return_value=False)
# log in twice, to get two devices
tok1 = self.login("localuser", "localpass")
@@ -303,7 +302,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# check_password must return an awaitable
- mock_password_provider.check_password.return_value = make_awaitable(False)
+ mock_password_provider.check_password = AsyncMock(return_value=False)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -325,7 +324,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# allow login via the auth provider
- mock_password_provider.check_password.return_value = make_awaitable(True)
+ mock_password_provider.check_password = AsyncMock(return_value=True)
# log in twice, to get two devices
tok1 = self.login("localuser", "p")
@@ -342,7 +341,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_not_called()
# now try deleting with the local password
- mock_password_provider.check_password.return_value = make_awaitable(False)
+ mock_password_provider.check_password = AsyncMock(return_value=False)
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
@@ -396,9 +395,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
mock_password_provider.check_auth.assert_not_called()
- mock_password_provider.check_auth.return_value = make_awaitable(
- ("@user:test", None)
- )
+ mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
self.assertEqual("@user:test", channel.json_body["user_id"])
@@ -447,9 +444,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# right params, but authing as the wrong user
- mock_password_provider.check_auth.return_value = make_awaitable(
- ("@user:test", None)
- )
+ mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None))
body["auth"]["test_field"] = "foo"
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 403)
@@ -460,8 +455,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# and finally, succeed
- mock_password_provider.check_auth.return_value = make_awaitable(
- ("@localuser:test", None)
+ mock_password_provider.check_auth = AsyncMock(
+ return_value=("@localuser:test", None)
)
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200)
@@ -478,10 +473,10 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.custom_auth_provider_callback_test_body()
def custom_auth_provider_callback_test_body(self) -> None:
- callback = Mock(return_value=make_awaitable(None))
+ callback = AsyncMock(return_value=None)
- mock_password_provider.check_auth.return_value = make_awaitable(
- ("@user:test", callback)
+ mock_password_provider.check_auth = AsyncMock(
+ return_value=("@user:test", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@@ -616,8 +611,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
- mock_password_provider.check_auth.return_value = make_awaitable(
- ("@localuser:test", None)
+ mock_password_provider.check_auth = AsyncMock(
+ return_value=("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@@ -835,11 +830,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
username: The username to use for the test.
registration: Whether to test with registration URLs.
"""
- self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment]
- return_value=make_awaitable(0),
+ self.hs.get_identity_handler().send_threepid_validation = AsyncMock( # type: ignore[method-assign]
+ return_value=0
)
- m = Mock(return_value=make_awaitable(False))
+ m = AsyncMock(return_value=False)
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
self.register_user(username, "password")
@@ -869,7 +864,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
m.assert_called_once_with("email", "foo@test.com", registration)
- m = Mock(return_value=make_awaitable(True))
+ m = AsyncMock(return_value=True)
self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m]
channel = self.make_request(
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 1aebcc16ad..88a16193a3 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -524,6 +524,7 @@ class PresenceHandlerInitTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_id = f"@test:{self.hs.config.server.server_name}"
+ self.device_id = "dev-1"
# Move the reactor to the initial time.
self.reactor.advance(1000)
@@ -608,7 +609,10 @@ class PresenceHandlerInitTestCase(unittest.HomeserverTestCase):
self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000 / 2)
self.get_success(
presence_handler.user_syncing(
- self.user_id, sync_state != PresenceState.OFFLINE, sync_state
+ self.user_id,
+ self.device_id,
+ sync_state != PresenceState.OFFLINE,
+ sync_state,
)
)
@@ -632,6 +636,7 @@ class PresenceHandlerInitTestCase(unittest.HomeserverTestCase):
class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
user_id = "@test:server"
user_id_obj = UserID.from_string(user_id)
+ device_id = "dev-1"
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.presence_handler = hs.get_presence_handler()
@@ -641,13 +646,20 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
"""Test that if an external process doesn't update the records for a while
we time out their syncing users presence.
"""
- process_id = "1"
- # Notify handler that a user is now syncing.
+ # Create a worker and use it to handle /sync traffic instead.
+ # This is used to test that presence changes get replicated from workers
+ # to the main process correctly.
+ worker_to_sync_against = self.make_worker_hs(
+ "synapse.app.generic_worker", {"worker_name": "synchrotron"}
+ )
+ worker_presence_handler = worker_to_sync_against.get_presence_handler()
+
self.get_success(
- self.presence_handler.update_external_syncs_row(
- process_id, self.user_id, True, self.clock.time_msec()
- )
+ worker_presence_handler.user_syncing(
+ self.user_id, self.device_id, True, PresenceState.ONLINE
+ ),
+ by=0.1,
)
# Check that if we wait a while without telling the handler the user has
@@ -701,7 +713,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# Mark user as offline
self.get_success(
self.presence_handler.set_state(
- self.user_id_obj, {"presence": PresenceState.OFFLINE}
+ self.user_id_obj, self.device_id, {"presence": PresenceState.OFFLINE}
)
)
@@ -733,7 +745,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# Mark user as online again
self.get_success(
self.presence_handler.set_state(
- self.user_id_obj, {"presence": PresenceState.ONLINE}
+ self.user_id_obj, self.device_id, {"presence": PresenceState.ONLINE}
)
)
@@ -762,7 +774,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.get_success(
self.presence_handler.user_syncing(
- self.user_id, False, PresenceState.ONLINE
+ self.user_id, self.device_id, False, PresenceState.ONLINE
)
)
@@ -779,7 +791,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)
self.get_success(
- self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE)
+ self.presence_handler.user_syncing(
+ self.user_id, self.device_id, True, PresenceState.ONLINE
+ )
)
state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
@@ -793,7 +807,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg)
self.get_success(
- self.presence_handler.user_syncing(self.user_id, True, PresenceState.ONLINE)
+ self.presence_handler.user_syncing(
+ self.user_id, self.device_id, True, PresenceState.ONLINE
+ )
)
state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
@@ -820,7 +836,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# This is used to test that presence changes get replicated from workers
# to the main process correctly.
worker_to_sync_against = self.make_worker_hs(
- "synapse.app.generic_worker", {"worker_name": "presence_writer"}
+ "synapse.app.generic_worker", {"worker_name": "synchrotron"}
)
# Set presence to BUSY
@@ -831,8 +847,9 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# /presence/*.
self.get_success(
worker_to_sync_against.get_presence_handler().user_syncing(
- self.user_id, True, PresenceState.ONLINE
- )
+ self.user_id, self.device_id, True, PresenceState.ONLINE
+ ),
+ by=0.1,
)
# Check against the main process that the user's presence did not change.
@@ -840,6 +857,21 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
# we should still be busy
self.assertEqual(state.state, PresenceState.BUSY)
+ # Advance such that the device would be discarded if it was not busy,
+ # then pump so _handle_timeouts function to called.
+ self.reactor.advance(IDLE_TIMER / 1000)
+ self.reactor.pump([5])
+
+ # The account should still be busy.
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
+ self.assertEqual(state.state, PresenceState.BUSY)
+
+ # Ensure that a /presence call can set the user *off* busy.
+ self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg)
+
+ state = self.get_success(self.presence_handler.get_state(self.user_id_obj))
+ self.assertEqual(state.state, PresenceState.ONLINE)
+
def _set_presencestate_with_status_msg(
self, state: str, status_msg: Optional[str]
) -> None:
@@ -852,6 +884,7 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase):
self.get_success(
self.presence_handler.set_state(
self.user_id_obj,
+ self.device_id,
{"presence": state, "status_msg": status_msg},
)
)
@@ -876,8 +909,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
prev_token = self.queue.get_current_token(self.instance_name)
- self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
- self.queue.send_presence_to_destinations((state3,), ("dest3",))
+ self.get_success(
+ self.queue.send_presence_to_destinations(
+ (state1, state2), ("dest1", "dest2")
+ )
+ )
+ self.get_success(
+ self.queue.send_presence_to_destinations((state3,), ("dest3",))
+ )
now_token = self.queue.get_current_token(self.instance_name)
@@ -913,11 +952,17 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
prev_token = self.queue.get_current_token(self.instance_name)
- self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+ self.get_success(
+ self.queue.send_presence_to_destinations(
+ (state1, state2), ("dest1", "dest2")
+ )
+ )
now_token = self.queue.get_current_token(self.instance_name)
- self.queue.send_presence_to_destinations((state3,), ("dest3",))
+ self.get_success(
+ self.queue.send_presence_to_destinations((state3,), ("dest3",))
+ )
rows, upto_token, limited = self.get_success(
self.queue.get_replication_rows("master", prev_token, now_token, 10)
@@ -956,8 +1001,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
prev_token = self.queue.get_current_token(self.instance_name)
- self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
- self.queue.send_presence_to_destinations((state3,), ("dest3",))
+ self.get_success(
+ self.queue.send_presence_to_destinations(
+ (state1, state2), ("dest1", "dest2")
+ )
+ )
+ self.get_success(
+ self.queue.send_presence_to_destinations((state3,), ("dest3",))
+ )
self.reactor.advance(10 * 60 * 1000)
@@ -972,8 +1023,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
prev_token = self.queue.get_current_token(self.instance_name)
- self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
- self.queue.send_presence_to_destinations((state3,), ("dest3",))
+ self.get_success(
+ self.queue.send_presence_to_destinations(
+ (state1, state2), ("dest1", "dest2")
+ )
+ )
+ self.get_success(
+ self.queue.send_presence_to_destinations((state3,), ("dest3",))
+ )
now_token = self.queue.get_current_token(self.instance_name)
@@ -1000,11 +1057,17 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
prev_token = self.queue.get_current_token(self.instance_name)
- self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
+ self.get_success(
+ self.queue.send_presence_to_destinations(
+ (state1, state2), ("dest1", "dest2")
+ )
+ )
self.reactor.advance(2 * 60 * 1000)
- self.queue.send_presence_to_destinations((state3,), ("dest3",))
+ self.get_success(
+ self.queue.send_presence_to_destinations((state3,), ("dest3",))
+ )
self.reactor.advance(4 * 60 * 1000)
@@ -1020,8 +1083,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase):
prev_token = self.queue.get_current_token(self.instance_name)
- self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2"))
- self.queue.send_presence_to_destinations((state3,), ("dest3",))
+ self.get_success(
+ self.queue.send_presence_to_destinations(
+ (state1, state2), ("dest1", "dest2")
+ )
+ )
+ self.get_success(
+ self.queue.send_presence_to_destinations((state3,), ("dest3",))
+ )
now_token = self.queue.get_current_token(self.instance_name)
@@ -1093,7 +1162,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# Mark test2 as online, test will be offline with a last_active of 0
self.get_success(
self.presence_handler.set_state(
- UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ UserID.from_string("@test2:server"),
+ "dev-1",
+ {"presence": PresenceState.ONLINE},
)
)
self.reactor.pump([0]) # Wait for presence updates to be handled
@@ -1140,7 +1211,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# Mark test as online
self.get_success(
self.presence_handler.set_state(
- UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE}
+ UserID.from_string("@test:server"),
+ "dev-1",
+ {"presence": PresenceState.ONLINE},
)
)
@@ -1148,7 +1221,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
# Note we don't join them to the room yet
self.get_success(
self.presence_handler.set_state(
- UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE}
+ UserID.from_string("@test2:server"),
+ "dev-1",
+ {"presence": PresenceState.ONLINE},
)
)
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index ec2f5d30be..f9b292b9ec 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Awaitable, Callable, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from parameterized import parameterized
@@ -26,7 +26,6 @@ from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import make_awaitable
class ProfileTestCase(unittest.HomeserverTestCase):
@@ -35,7 +34,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
servlets = [admin.register_servlets]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- self.mock_federation = Mock()
+ self.mock_federation = AsyncMock()
self.mock_registry = Mock()
self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {}
@@ -135,9 +134,7 @@ class ProfileTestCase(unittest.HomeserverTestCase):
)
def test_get_other_name(self) -> None:
- self.mock_federation.make_query.return_value = make_awaitable(
- {"displayname": "Alice"}
- )
+ self.mock_federation.make_query.return_value = {"displayname": "Alice"}
displayname = self.get_success(self.handler.get_displayname(self.alice))
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 54eeec228e..e9fbf32c7c 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Collection, List, Optional, Tuple
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -38,7 +38,6 @@ from synapse.types import (
)
from synapse.util import Clock
-from tests.test_utils import make_awaitable
from tests.unittest import override_config
from tests.utils import mock_getRawHeaders
@@ -203,24 +202,22 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self) -> None:
- self.store.count_monthly_users = Mock( # type: ignore[assignment]
- return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
+ self.store.count_monthly_users = AsyncMock( # type: ignore[method-assign]
+ return_value=self.hs.config.server.max_mau_value - 1
)
# Ensure does not throw exception
self.get_success(self.get_or_create_user(self.requester, "c", "User"))
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self) -> None:
- self.store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.lots_of_users)
- )
+ self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
ResourceLimitError,
)
- self.store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.hs.config.server.max_mau_value)
+ self.store.get_monthly_active_count = AsyncMock(
+ return_value=self.hs.config.server.max_mau_value
)
self.get_failure(
self.get_or_create_user(self.requester, "b", "display_name"),
@@ -229,15 +226,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_register_mau_blocked(self) -> None:
- self.store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.lots_of_users)
- )
+ self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
)
- self.store.get_monthly_active_count = Mock(
- return_value=make_awaitable(self.hs.config.server.max_mau_value)
+ self.store.get_monthly_active_count = AsyncMock(
+ return_value=self.hs.config.server.max_mau_value
)
self.get_failure(
self.handler.register_user(localpart="local_part"), ResourceLimitError
@@ -292,7 +287,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"auto_join_rooms": ["#room:test"]})
def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None:
room_alias_str = "#room:test"
- self.store.is_real_user = Mock(return_value=make_awaitable(False))
+ self.store.is_real_user = AsyncMock(return_value=False)
user_id = self.get_success(self.handler.register_user(localpart="support"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
@@ -304,8 +299,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None:
room_alias_str = "#room:test"
- self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment]
- self.store.is_real_user = Mock(return_value=make_awaitable(True))
+ self.store.count_real_users = AsyncMock(return_value=1) # type: ignore[method-assign]
+ self.store.is_real_user = AsyncMock(return_value=True)
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
directory_handler = self.hs.get_directory_handler()
@@ -319,8 +314,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user(
self,
) -> None:
- self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment]
- self.store.is_real_user = Mock(return_value=make_awaitable(True))
+ self.store.count_real_users = AsyncMock(return_value=2) # type: ignore[method-assign]
+ self.store.is_real_user = AsyncMock(return_value=True)
user_id = self.get_success(self.handler.register_user(localpart="real"))
rooms = self.get_success(self.store.get_rooms_for_user(user_id))
self.assertEqual(len(rooms), 0)
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index 41199ffa29..3e28117e2c 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -1,4 +1,4 @@
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, patch
from twisted.test.proto_helpers import MemoryReactor
@@ -16,7 +16,6 @@ from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
-from tests.test_utils import make_awaitable
from tests.unittest import (
FederatingHomeserverTestCase,
HomeserverTestCase,
@@ -154,25 +153,21 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
None,
)
- mock_make_membership_event = Mock(
- return_value=make_awaitable(
- (
- self.OTHER_SERVER_NAME,
- join_event,
- self.hs.config.server.default_room_version,
- )
+ mock_make_membership_event = AsyncMock(
+ return_value=(
+ self.OTHER_SERVER_NAME,
+ join_event,
+ self.hs.config.server.default_room_version,
)
)
- mock_send_join = Mock(
- return_value=make_awaitable(
- SendJoinResult(
- join_event,
- self.OTHER_SERVER_NAME,
- state=[create_event],
- auth_chain=[create_event],
- partial_state=False,
- servers_in_room=frozenset(),
- )
+ mock_send_join = AsyncMock(
+ return_value=SendJoinResult(
+ join_event,
+ self.OTHER_SERVER_NAME,
+ state=[create_event],
+ auth_chain=[create_event],
+ partial_state=False,
+ servers_in_room=frozenset(),
)
)
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index b5c772a7ae..00f4e181e8 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Any, Dict, Optional, Set, Tuple
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
import attr
@@ -25,7 +25,6 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
-from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
# Check if we have the dependencies to run the tests.
@@ -134,7 +133,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
+ auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
# send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@@ -164,7 +163,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
+ auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
# Map a user via SSO.
saml_response = FakeAuthnResponse(
@@ -206,11 +205,11 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
+ auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
# mock out the error renderer too
sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
+ sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign]
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
request = _mock_request()
@@ -227,9 +226,9 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler and error renderer
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
+ auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
sso_handler = self.hs.get_sso_handler()
- sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment]
+ sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign]
# register a user to occupy the first-choice MXID
store = self.hs.get_datastores().main
@@ -312,7 +311,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# stub out the auth handler
auth_handler = self.hs.get_auth_handler()
- auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment]
+ auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign]
# The response doesn't have the proper userGroup or department.
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py
index 8b6e4a40b6..a066745d70 100644
--- a/tests/handlers/test_send_email.py
+++ b/tests/handlers/test_send_email.py
@@ -13,19 +13,40 @@
# limitations under the License.
-from typing import Callable, List, Tuple
+from typing import Callable, List, Tuple, Type, Union
+from unittest.mock import patch
from zope.interface import implementer
from twisted.internet import defer
-from twisted.internet.address import IPv4Address
+from twisted.internet._sslverify import ClientTLSOptions
+from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import ensureDeferred
+from twisted.internet.interfaces import IProtocolFactory
+from twisted.internet.ssl import ContextFactory
from twisted.mail import interfaces, smtp
from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase, override_config
+def TestingESMTPTLSClientFactory(
+ contextFactory: ContextFactory,
+ _connectWrapped: bool,
+ wrappedProtocol: IProtocolFactory,
+) -> IProtocolFactory:
+ """We use this to pass through in testing without using TLS, but
+ saving the context information to check that it would have happened.
+
+ Note that this is what the MemoryReactor does on connectSSL.
+ It only saves the contextFactory, but starts the connection with the
+ underlying Factory.
+ See: L{twisted.internet.testing.MemoryReactor.connectSSL}"""
+
+ wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined]
+ return wrappedProtocol
+
+
@implementer(interfaces.IMessageDelivery)
class _DummyMessageDelivery:
def __init__(self) -> None:
@@ -75,7 +96,13 @@ class _DummyMessage:
pass
-class SendEmailHandlerTestCase(HomeserverTestCase):
+class SendEmailHandlerTestCaseIPv4(HomeserverTestCase):
+ ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address
+
+ def setUp(self) -> None:
+ super().setUp()
+ self.reactor.lookups["localhost"] = "127.0.0.1"
+
def test_send_email(self) -> None:
"""Happy-path test that we can send email to a non-TLS server."""
h = self.hs.get_send_email_handler()
@@ -89,7 +116,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
(host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
0
]
- self.assertEqual(host, "localhost")
+ self.assertEqual(host, self.reactor.lookups["localhost"])
self.assertEqual(port, 25)
# wire it up to an SMTP server
@@ -105,7 +132,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
FakeTransport(
client_protocol,
self.reactor,
- peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+ peer_address=self.ip_class(
+ "TCP", self.reactor.lookups["localhost"], 1234
+ ),
)
)
@@ -118,6 +147,10 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
self.assertEqual(str(user), "foo@bar.com")
self.assertIn(b"Subject: test subject", msg)
+ @patch(
+ "synapse.handlers.send_email.TLSMemoryBIOFactory",
+ TestingESMTPTLSClientFactory,
+ )
@override_config(
{
"email": {
@@ -135,17 +168,23 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
)
)
# there should be an attempt to connect to localhost:465
- self.assertEqual(len(self.reactor.sslClients), 1)
+ self.assertEqual(len(self.reactor.tcpClients), 1)
(
host,
port,
client_factory,
- contextFactory,
_timeout,
_bindAddress,
- ) = self.reactor.sslClients[0]
- self.assertEqual(host, "localhost")
+ ) = self.reactor.tcpClients[0]
+ self.assertEqual(host, self.reactor.lookups["localhost"])
self.assertEqual(port, 465)
+ # We need to make sure that TLS is happenning
+ self.assertIsInstance(
+ client_factory._wrappedFactory._testingContextFactory,
+ ClientTLSOptions,
+ )
+ # And since we use endpoints, they go through reactor.connectTCP
+ # which works differently to connectSSL on the testing reactor
# wire it up to an SMTP server
message_delivery = _DummyMessageDelivery()
@@ -160,7 +199,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
FakeTransport(
client_protocol,
self.reactor,
- peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
+ peer_address=self.ip_class(
+ "TCP", self.reactor.lookups["localhost"], 1234
+ ),
)
)
@@ -172,3 +213,11 @@ class SendEmailHandlerTestCase(HomeserverTestCase):
user, msg = message_delivery.messages.pop()
self.assertEqual(str(user), "foo@bar.com")
self.assertIn(b"Subject: test subject", msg)
+
+
+class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4):
+ ip_class = IPv6Address
+
+ def setUp(self) -> None:
+ super().setUp()
+ self.reactor.lookups["localhost"] = "::1"
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 9f035a02dc..948d04fc32 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
-from unittest.mock import MagicMock, Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
from twisted.test.proto_helpers import MemoryReactor
@@ -29,7 +29,6 @@ from synapse.util import Clock
import tests.unittest
import tests.utils
-from tests.test_utils import make_awaitable
class SyncTestCase(tests.unittest.HomeserverTestCase):
@@ -253,8 +252,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
mocked_get_prev_events = patch.object(
self.hs.get_datastores().main,
"get_prev_events_for_room",
- new_callable=MagicMock,
- return_value=make_awaitable([last_room_creation_event_id]),
+ new_callable=AsyncMock,
+ return_value=[last_room_creation_event_id],
)
with mocked_get_prev_events:
self.helper.join(room_id, eve, tok=eve_token)
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 5da1d95f0b..95106ec8f3 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,7 +15,7 @@
import json
from typing import Dict, List, Set
-from unittest.mock import ANY, Mock, call
+from unittest.mock import ANY, AsyncMock, Mock, call
from netaddr import IPSet
@@ -33,7 +33,6 @@ from synapse.util import Clock
from tests import unittest
from tests.server import ThreadedMemoryReactorClock
-from tests.test_utils import make_awaitable
from tests.unittest import override_config
# Some local users to test with
@@ -74,11 +73,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
- mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
+ mock_keyring.verify_json_for_server = AsyncMock(return_value=True)
# we mock out the federation client too
- self.mock_federation_client = Mock(spec=["put_json"])
- self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
+ self.mock_federation_client = AsyncMock(spec=["put_json"])
+ self.mock_federation_client.put_json.return_value = (200, "OK")
self.mock_federation_client.agent = MatrixFederationAgent(
reactor,
tls_client_options_factory=None,
@@ -121,20 +120,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore = hs.get_datastores().main
- self.datastore.get_destination_retry_timings = Mock(
- return_value=make_awaitable(None)
+ self.datastore.get_device_updates_by_remote = AsyncMock( # type: ignore[method-assign]
+ return_value=(0, [])
)
- self.datastore.get_device_updates_by_remote = Mock( # type: ignore[assignment]
- return_value=make_awaitable((0, []))
+ self.datastore.get_destination_last_successful_stream_ordering = AsyncMock( # type: ignore[method-assign]
+ return_value=None
)
- self.datastore.get_destination_last_successful_stream_ordering = Mock( # type: ignore[assignment]
- return_value=make_awaitable(None)
- )
-
- self.datastore.get_received_txn_response = Mock( # type: ignore[assignment]
- return_value=make_awaitable(None)
+ self.datastore.get_received_txn_response = AsyncMock( # type: ignore[method-assign]
+ return_value=None
)
self.room_members: List[UserID] = []
@@ -146,25 +141,25 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
raise AuthError(401, "User is not in the room")
return None
- hs.get_auth().check_user_in_room = Mock( # type: ignore[assignment]
+ hs.get_auth().check_user_in_room = Mock( # type: ignore[method-assign]
side_effect=check_user_in_room
)
async def check_host_in_room(room_id: str, server_name: str) -> bool:
return room_id == ROOM_ID
- hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[assignment]
+ hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[method-assign]
side_effect=check_host_in_room
)
async def get_current_hosts_in_room(room_id: str) -> Set[str]:
return {member.domain for member in self.room_members}
- hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
+ hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[method-assign]
side_effect=get_current_hosts_in_room
)
- hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[assignment]
+ hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[method-assign]
side_effect=get_current_hosts_in_room
)
@@ -173,27 +168,25 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room)
- self.datastore.get_user_directory_stream_pos = Mock( # type: ignore[assignment]
- side_effect=(
- # we deliberately return a non-None stream pos to avoid
- # doing an initial_sync
- lambda: make_awaitable(1)
- )
+ self.datastore.get_user_directory_stream_pos = AsyncMock( # type: ignore[method-assign]
+ # we deliberately return a non-None stream pos to avoid
+ # doing an initial_sync
+ return_value=1
)
- self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[assignment]
+ self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[method-assign]
- self.datastore.get_to_device_stream_token = Mock( # type: ignore[assignment]
- side_effect=lambda: 0
+ self.datastore.get_to_device_stream_token = Mock( # type: ignore[method-assign]
+ return_value=0
)
- self.datastore.get_new_device_msgs_for_remote = Mock( # type: ignore[assignment]
- side_effect=lambda *args, **kargs: make_awaitable(([], 0))
+ self.datastore.get_new_device_msgs_for_remote = AsyncMock( # type: ignore[method-assign]
+ return_value=([], 0)
)
- self.datastore.delete_device_msgs_for_remote = Mock( # type: ignore[assignment]
- side_effect=lambda *args, **kargs: make_awaitable(None)
+ self.datastore.delete_device_msgs_for_remote = AsyncMock( # type: ignore[method-assign]
+ return_value=None
)
- self.datastore.set_received_txn_response = Mock( # type: ignore[assignment]
- side_effect=lambda *args, **kwargs: make_awaitable(None)
+ self.datastore.set_received_txn_response = AsyncMock( # type: ignore[method-assign]
+ return_value=None
)
def test_started_typing_local(self) -> None:
@@ -256,8 +249,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
),
json_data_callback=ANY,
long_retries=True,
- backoff_on_404=True,
try_trailing_slash_on_400=True,
+ backoff_on_all_error_codes=True,
)
def test_started_typing_remote_recv(self) -> None:
@@ -371,7 +364,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
),
json_data_callback=ANY,
long_retries=True,
- backoff_on_404=True,
+ backoff_on_all_error_codes=True,
try_trailing_slash_on_400=True,
)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 430209705e..b5f15aa7d4 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Tuple
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
from urllib.parse import quote
from twisted.test.proto_helpers import MemoryReactor
@@ -30,7 +30,7 @@ from synapse.util import Clock
from tests import unittest
from tests.storage.test_user_directory import GetUserDirectoryTables
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config
@@ -471,7 +471,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.store.register_user(user_id=r_user_id, password_hash=None)
)
- mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
+ mock_remove_from_user_dir = AsyncMock(return_value=None)
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):
|