diff options
Diffstat (limited to 'tests')
145 files changed, 8165 insertions, 3180 deletions
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 6e36e73f0d..e00d7215df 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import pymacaroons from twisted.test.proto_helpers import MemoryReactor -from synapse.api.auth import Auth +from synapse.api.auth.internal import InternalAuth from synapse.api.auth_blocking import AuthBlocking from synapse.api.constants import UserTypes from synapse.api.errors import ( @@ -35,7 +35,6 @@ from synapse.types import Requester, UserID from synapse.util import Clock from tests import unittest -from tests.test_utils import simple_async_mock from tests.unittest import override_config from tests.utils import mock_getRawHeaders @@ -48,7 +47,7 @@ class AuthTestCase(unittest.HomeserverTestCase): # have been called by the HomeserverTestCase machinery. hs.datastores.main = self.store # type: ignore[union-attr] hs.get_auth_handler().store = self.store - self.auth = Auth(hs) + self.auth = InternalAuth(hs) # AuthBlocking reads from the hs' config on initialization. We need to # modify its config instead of the hs' @@ -60,15 +59,16 @@ class AuthTestCase(unittest.HomeserverTestCase): # this is overridden for the appservice tests self.store.get_app_service_by_token = Mock(return_value=None) - self.store.insert_client_ip = simple_async_mock(None) - self.store.is_support_user = simple_async_mock(False) + self.store.insert_client_ip = AsyncMock(return_value=None) + self.store.is_support_user = AsyncMock(return_value=False) def test_get_user_by_req_user_valid_token(self) -> None: user_info = TokenLookupResult( user_id=self.test_user, token_id=5, device_id="device" ) - self.store.get_user_by_access_token = simple_async_mock(user_info) - self.store.mark_access_token_as_used = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=user_info) + self.store.mark_access_token_as_used = AsyncMock(return_value=None) + self.store.get_user_locked_status = AsyncMock(return_value=False) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -77,7 +77,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self) -> None: - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -90,7 +90,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_get_user_by_req_user_missing_token(self) -> None: user_info = TokenLookupResult(user_id=self.test_user, token_id=5) - self.store.get_user_by_access_token = simple_async_mock(user_info) + self.store.get_user_by_access_token = AsyncMock(return_value=user_info) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -105,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase): token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -124,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "192.168.10.10" @@ -143,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ip_range_whitelist=IPSet(["192.168/16"]), ) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "131.111.8.42" @@ -157,7 +157,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_get_user_by_req_appservice_bad_token(self) -> None: self.store.get_app_service_by_token = Mock(return_value=None) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.args[b"access_token"] = [self.test_token] @@ -171,7 +171,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_get_user_by_req_appservice_missing_token(self) -> None: app_service = Mock(token="foobar", url="a_url", sender=self.test_user) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -188,9 +188,12 @@ class AuthTestCase(unittest.HomeserverTestCase): ) app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) - # This just needs to return a truth-y value. - self.store.get_user_by_id = simple_async_mock({"is_guest": False}) - self.store.get_user_by_access_token = simple_async_mock(None) + + class FakeUserInfo: + is_guest = False + + self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo()) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -209,7 +212,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) app_service.is_interested_in_user = Mock(return_value=False) self.store.get_app_service_by_token = Mock(return_value=app_service) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -233,10 +236,10 @@ class AuthTestCase(unittest.HomeserverTestCase): app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) # This just needs to return a truth-y value. - self.store.get_user_by_id = simple_async_mock({"is_guest": False}) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) + self.store.get_user_by_access_token = AsyncMock(return_value=None) # This also needs to just return a truth-y value - self.store.get_device = simple_async_mock({"hidden": False}) + self.store.get_device = AsyncMock(return_value={"hidden": False}) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -265,10 +268,10 @@ class AuthTestCase(unittest.HomeserverTestCase): app_service.is_interested_in_user = Mock(return_value=True) self.store.get_app_service_by_token = Mock(return_value=app_service) # This just needs to return a truth-y value. - self.store.get_user_by_id = simple_async_mock({"is_guest": False}) - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False}) + self.store.get_user_by_access_token = AsyncMock(return_value=None) # This also needs to just return a falsey value - self.store.get_device = simple_async_mock(None) + self.store.get_device = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" @@ -282,8 +285,8 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE) def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None: - self.store.get_user_by_access_token = simple_async_mock( - TokenLookupResult( + self.store.get_user_by_access_token = AsyncMock( + return_value=TokenLookupResult( user_id="@baldrick:matrix.org", device_id="device", token_id=5, @@ -291,8 +294,9 @@ class AuthTestCase(unittest.HomeserverTestCase): token_used=True, ) ) - self.store.insert_client_ip = simple_async_mock(None) - self.store.mark_access_token_as_used = simple_async_mock(None) + self.store.insert_client_ip = AsyncMock(return_value=None) + self.store.mark_access_token_as_used = AsyncMock(return_value=None) + self.store.get_user_locked_status = AsyncMock(return_value=False) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] @@ -302,8 +306,8 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None: self.auth._track_puppeted_user_ips = True - self.store.get_user_by_access_token = simple_async_mock( - TokenLookupResult( + self.store.get_user_by_access_token = AsyncMock( + return_value=TokenLookupResult( user_id="@baldrick:matrix.org", device_id="device", token_id=5, @@ -311,8 +315,9 @@ class AuthTestCase(unittest.HomeserverTestCase): token_used=True, ) ) - self.store.insert_client_ip = simple_async_mock(None) - self.store.mark_access_token_as_used = simple_async_mock(None) + self.store.get_user_locked_status = AsyncMock(return_value=False) + self.store.insert_client_ip = AsyncMock(return_value=None) + self.store.mark_access_token_as_used = AsyncMock(return_value=None) request = Mock(args={}) request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] @@ -321,7 +326,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(self.store.insert_client_ip.call_count, 2) def test_get_user_from_macaroon(self) -> None: - self.store.get_user_by_access_token = simple_async_mock(None) + self.store.get_user_by_access_token = AsyncMock(return_value=None) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -339,8 +344,11 @@ class AuthTestCase(unittest.HomeserverTestCase): ) def test_get_guest_user_from_macaroon(self) -> None: - self.store.get_user_by_id = simple_async_mock({"is_guest": True}) - self.store.get_user_by_access_token = simple_async_mock(None) + class FakeUserInfo: + is_guest = True + + self.store.get_user_by_id = AsyncMock(return_value=FakeUserInfo()) + self.store.get_user_by_access_token = AsyncMock(return_value=None) user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( @@ -370,7 +378,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = True - self.store.get_monthly_active_count = simple_async_mock(lots_of_users) + self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users) e = self.get_failure( self.auth_blocking.check_auth_blocking(), ResourceLimitError @@ -380,25 +388,27 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertEqual(e.value.code, 403) # Ensure does not throw an error - self.store.get_monthly_active_count = simple_async_mock(small_number_of_users) + self.store.get_monthly_active_count = AsyncMock( + return_value=small_number_of_users + ) self.get_success(self.auth_blocking.check_auth_blocking()) def test_blocking_mau__depending_on_user_type(self) -> None: self.auth_blocking._max_mau_value = 50 self.auth_blocking._limit_usage_by_mau = True - self.store.get_monthly_active_count = simple_async_mock(100) + self.store.get_monthly_active_count = AsyncMock(return_value=100) # Support users allowed self.get_success( self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT) ) - self.store.get_monthly_active_count = simple_async_mock(100) + self.store.get_monthly_active_count = AsyncMock(return_value=100) # Bots not allowed self.get_failure( self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT), ResourceLimitError, ) - self.store.get_monthly_active_count = simple_async_mock(100) + self.store.get_monthly_active_count = AsyncMock(return_value=100) # Real users not allowed self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError) @@ -409,9 +419,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._track_appservice_user_ips = False - self.store.get_monthly_active_count = simple_async_mock(100) - self.store.user_last_seen_monthly_active = simple_async_mock() - self.store.is_trial_user = simple_async_mock() + self.store.get_monthly_active_count = AsyncMock(return_value=100) + self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) + self.store.is_trial_user = AsyncMock(return_value=False) appservice = ApplicationService( "abcd", @@ -426,6 +436,7 @@ class AuthTestCase(unittest.HomeserverTestCase): access_token_id=None, device_id="FOOBAR", is_guest=False, + scope=set(), shadow_banned=False, app_service=appservice, authenticated_entity="@appservice:server", @@ -439,9 +450,9 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._track_appservice_user_ips = True - self.store.get_monthly_active_count = simple_async_mock(100) - self.store.user_last_seen_monthly_active = simple_async_mock() - self.store.is_trial_user = simple_async_mock() + self.store.get_monthly_active_count = AsyncMock(return_value=100) + self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) + self.store.is_trial_user = AsyncMock(return_value=False) appservice = ApplicationService( "abcd", @@ -456,6 +467,7 @@ class AuthTestCase(unittest.HomeserverTestCase): access_token_id=None, device_id="FOOBAR", is_guest=False, + scope=set(), shadow_banned=False, app_service=appservice, authenticated_entity="@appservice:server", @@ -468,7 +480,7 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_reserved_threepid(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._max_mau_value = 1 - self.store.get_monthly_active_count = simple_async_mock(2) + self.store.get_monthly_active_count = AsyncMock(return_value=2) threepid = {"medium": "email", "address": "reserved@server.com"} unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.auth_blocking._mau_limits_reserved_threepids = [threepid] diff --git a/tests/api/test_errors.py b/tests/api/test_errors.py new file mode 100644 index 0000000000..8e159029d9 --- /dev/null +++ b/tests/api/test_errors.py @@ -0,0 +1,43 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from synapse.api.errors import LimitExceededError + +from tests import unittest + + +class LimitExceededErrorTestCase(unittest.TestCase): + def test_key_appears_in_context_but_not_error_dict(self) -> None: + err = LimitExceededError("needle") + serialised = json.dumps(err.error_dict(None)) + self.assertIn("needle", err.debug_context) + self.assertNotIn("needle", serialised) + + # Create a sub-class to avoid mutating the class-level property. + class LimitExceededErrorHeaders(LimitExceededError): + include_retry_after_header = True + + def test_limit_exceeded_header(self) -> None: + err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=100) + self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100) + assert err.headers is not None + self.assertEqual(err.headers.get("Retry-After"), "1") + + def test_limit_exceeded_rounding(self) -> None: + err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=3001) + self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001) + assert err.headers is not None + self.assertEqual(err.headers.get("Retry-After"), "4") diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py index 222449baac..868f0c6995 100644 --- a/tests/api/test_filtering.py +++ b/tests/api/test_filtering.py @@ -35,7 +35,6 @@ from tests.events.test_utils import MockEvent user_id = UserID.from_string("@test_user:test") user2_id = UserID.from_string("@test_user2:test") -user_localpart = "test_user" class FilteringTestCase(unittest.HomeserverTestCase): @@ -48,8 +47,6 @@ class FilteringTestCase(unittest.HomeserverTestCase): invalid_filters: List[JsonDict] = [ # `account_data` must be a dictionary {"account_data": "Hello World"}, - # `event_fields` entries must not contain backslashes - {"event_fields": [r"\\foo"]}, # `event_format` must be "client" or "federation" {"event_format": "other"}, # `not_rooms` must contain valid room IDs @@ -114,10 +111,6 @@ class FilteringTestCase(unittest.HomeserverTestCase): "event_format": "client", "event_fields": ["type", "content", "sender"], }, - # a single backslash should be permitted (though it is debatable whether - # it should be permitted before anything other than `.`, and what that - # actually means) - # # (note that event_fields is implemented in # synapse.events.utils.serialize_event, and so whether this actually works # is tested elsewhere. We just want to check that it is allowed through the @@ -455,9 +448,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ] user_filter = self.get_success( - self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id - ) + self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id) ) results = self.get_success(user_filter.filter_presence(presence_states)) @@ -485,9 +476,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ] user_filter = self.get_success( - self.filtering.get_user_filter( - user_localpart=user_localpart + "2", filter_id=filter_id - ) + self.filtering.get_user_filter(user_id=user2_id, filter_id=filter_id) ) results = self.get_success(user_filter.filter_presence(presence_states)) @@ -504,9 +493,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): events = [event] user_filter = self.get_success( - self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id - ) + self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id) ) results = self.get_success(user_filter.filter_room_state(events=events)) @@ -525,9 +512,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): events = [event] user_filter = self.get_success( - self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id - ) + self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id) ) results = self.get_success(user_filter.filter_room_state(events)) @@ -609,9 +594,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): user_filter_json, ( self.get_success( - self.datastore.get_user_filter( - user_localpart=user_localpart, filter_id=0 - ) + self.datastore.get_user_filter(user_id=user_id, filter_id=0) ) ), ) @@ -626,9 +609,7 @@ class FilteringTestCase(unittest.HomeserverTestCase): ) filter = self.get_success( - self.filtering.get_user_filter( - user_localpart=user_localpart, filter_id=filter_id - ) + self.filtering.get_user_filter(user_id=user_id, filter_id=filter_id) ) self.assertEqual(filter.get_filter_json(), user_filter_json) diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index fa6c1c02ce..a24638c9ef 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -1,5 +1,6 @@ from synapse.api.ratelimiting import LimitExceededError, Ratelimiter from synapse.appservice import ApplicationService +from synapse.config.ratelimiting import RatelimitSettings from synapse.types import create_requester from tests import unittest @@ -10,8 +11,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(None, key="test_id", _time_now_s=0) @@ -43,8 +43,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings( + key="", + per_second=0.1, + burst_count=1, + ), ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -76,8 +79,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings( + key="", + per_second=0.1, + burst_count=1, + ), ) allowed, time_allowed = self.get_success_or_raise( limiter.can_do_action(as_requester, _time_now_s=0) @@ -101,8 +107,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) # Shouldn't raise @@ -128,8 +133,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) # First attempt should be allowed @@ -177,8 +181,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) # First attempt should be allowed @@ -208,8 +211,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=1, + cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1), ) self.get_success_or_raise( limiter.can_do_action(None, key="test_id_1", _time_now_s=0) @@ -244,7 +246,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): ) ) - limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1) + limiter = Ratelimiter( + store=store, + clock=self.clock, + cfg=RatelimitSettings("", per_second=0.1, burst_count=1), + ) # Shouldn't raise for _ in range(20): @@ -254,8 +260,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings( + key="", + per_second=0.1, + burst_count=3, + ), ) # Test that 4 actions aren't allowed with a maximum burst of 3. allowed, time_allowed = self.get_success_or_raise( @@ -321,8 +330,7 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings("", per_second=0.1, burst_count=3), ) def consume_at(time: float) -> bool: @@ -346,8 +354,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings( + "", + per_second=0.1, + burst_count=3, + ), ) # Observe two actions, leaving room in the bucket for one more. @@ -369,8 +380,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings( + "", + per_second=0.1, + burst_count=3, + ), ) # Observe three actions, filling up the bucket. @@ -398,8 +412,11 @@ class TestRatelimiter(unittest.HomeserverTestCase): limiter = Ratelimiter( store=self.hs.get_datastores().main, clock=self.clock, - rate_hz=0.1, - burst_count=3, + cfg=RatelimitSettings( + "", + per_second=0.1, + burst_count=3, + ), ) # Observe four actions, exceeding the bucket. diff --git a/tests/app/test_homeserver_start.py b/tests/app/test_homeserver_start.py index 788c935537..0201933b04 100644 --- a/tests/app/test_homeserver_start.py +++ b/tests/app/test_homeserver_start.py @@ -25,7 +25,9 @@ class HomeserverAppStartTestCase(ConfigFileTestCase): # Add a blank line as otherwise the next addition ends up on a line with a comment self.add_lines_to_config([" "]) self.add_lines_to_config(["worker_app: test_worker_app"]) - + self.add_lines_to_config(["worker_log_config: /data/logconfig.config"]) + self.add_lines_to_config(["instance_map:"]) + self.add_lines_to_config([" main:", " host: 127.0.0.1", " port: 1234"]) # Ensure that starting master process with worker config raises an exception with self.assertRaises(ConfigError): synapse.app.homeserver.setup(["-c", self.config_file]) diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py index 6e0413400e..21c5309740 100644 --- a/tests/app/test_openid_listener.py +++ b/tests/app/test_openid_listener.py @@ -31,9 +31,7 @@ from tests.unittest import HomeserverTestCase class FederationReaderOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - hs = self.setup_test_homeserver( - federation_http_client=None, homeserver_to_use=GenericWorkerServer - ) + hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer) return hs def default_config(self) -> JsonDict: @@ -42,6 +40,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): # have to tell the FederationHandler not to try to access stuff that is only # in the primary store. conf["worker_app"] = "yes" + conf["instance_map"] = {"main": {"host": "127.0.0.1", "port": 0}} return conf @@ -90,9 +89,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase): @patch("synapse.app.homeserver.KeyResource", new=Mock()) class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - hs = self.setup_test_homeserver( - federation_http_client=None, homeserver_to_use=SynapseHomeServer - ) + hs = self.setup_test_homeserver(homeserver_to_use=SynapseHomeServer) return hs @parameterized.expand( diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index a860eedbcf..93af614def 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -4,7 +4,6 @@ from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.util import Clock -from tests import unittest from tests.server import ThreadedMemoryReactorClock from tests.unittest import HomeserverTestCase @@ -12,154 +11,6 @@ FIVE_MINUTES_IN_SECONDS = 300 ONE_DAY_IN_SECONDS = 86400 -class PhoneHomeTestCase(HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets_for_client_rest_resource, - room.register_servlets, - login.register_servlets, - ] - - # Override the retention time for the user_ips table because otherwise it - # gets pruned too aggressively for our R30 test. - @unittest.override_config({"user_ips_max_age": "365d"}) - def test_r30_minimum_usage(self) -> None: - """ - Tests the minimum amount of interaction necessary for the R30 metric - to consider a user 'retained'. - """ - - # Register a user, log it in, create a room and send a message - user_id = self.register_user("u1", "secret!") - access_token = self.login("u1", "secret!") - room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) - self.helper.send(room_id, "message", tok=access_token) - - # Check the R30 results do not count that user. - r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) - self.assertEqual(r30_results, {"all": 0}) - - # Advance 30 days (+ 1 second, because strict inequality causes issues if we are - # bang on 30 days later). - self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) - - # (Make sure the user isn't somehow counted by this point.) - r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) - self.assertEqual(r30_results, {"all": 0}) - - # Send a message (this counts as activity) - self.helper.send(room_id, "message2", tok=access_token) - - # We have to wait some time for _update_client_ips_batch to get - # called and update the user_ips table. - self.reactor.advance(2 * 60 * 60) - - # *Now* the user is counted. - r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) - self.assertEqual(r30_results, {"all": 1, "unknown": 1}) - - # Advance 29 days. The user has now not posted for 29 days. - self.reactor.advance(29 * ONE_DAY_IN_SECONDS) - - # The user is still counted. - r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) - self.assertEqual(r30_results, {"all": 1, "unknown": 1}) - - # Advance another day. The user has now not posted for 30 days. - self.reactor.advance(ONE_DAY_IN_SECONDS) - - # The user is now no longer counted in R30. - r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) - self.assertEqual(r30_results, {"all": 0}) - - def test_r30_minimum_usage_using_default_config(self) -> None: - """ - Tests the minimum amount of interaction necessary for the R30 metric - to consider a user 'retained'. - - N.B. This test does not override the `user_ips_max_age` config setting, - which defaults to 28 days. - """ - - # Register a user, log it in, create a room and send a message - user_id = self.register_user("u1", "secret!") - access_token = self.login("u1", "secret!") - room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) - self.helper.send(room_id, "message", tok=access_token) - - # Check the R30 results do not count that user. - r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) - self.assertEqual(r30_results, {"all": 0}) - - # Advance 30 days (+ 1 second, because strict inequality causes issues if we are - # bang on 30 days later). - self.reactor.advance(30 * ONE_DAY_IN_SECONDS + 1) - - # (Make sure the user isn't somehow counted by this point.) - r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) - self.assertEqual(r30_results, {"all": 0}) - - # Send a message (this counts as activity) - self.helper.send(room_id, "message2", tok=access_token) - - # We have to wait some time for _update_client_ips_batch to get - # called and update the user_ips table. - self.reactor.advance(2 * 60 * 60) - - # *Now* the user is counted. - r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) - self.assertEqual(r30_results, {"all": 1, "unknown": 1}) - - # Advance 27 days. The user has now not posted for 27 days. - self.reactor.advance(27 * ONE_DAY_IN_SECONDS) - - # The user is still counted. - r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) - self.assertEqual(r30_results, {"all": 1, "unknown": 1}) - - # Advance another day. The user has now not posted for 28 days. - self.reactor.advance(ONE_DAY_IN_SECONDS) - - # The user is now no longer counted in R30. - # (This is because the user_ips table has been pruned, which by default - # only preserves the last 28 days of entries.) - r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) - self.assertEqual(r30_results, {"all": 0}) - - def test_r30_user_must_be_retained_for_at_least_a_month(self) -> None: - """ - Tests that a newly-registered user must be retained for a whole month - before appearing in the R30 statistic, even if they post every day - during that time! - """ - # Register a user and send a message - user_id = self.register_user("u1", "secret!") - access_token = self.login("u1", "secret!") - room_id = self.helper.create_room_as(room_creator=user_id, tok=access_token) - self.helper.send(room_id, "message", tok=access_token) - - # Check the user does not contribute to R30 yet. - r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) - self.assertEqual(r30_results, {"all": 0}) - - for _ in range(30): - # This loop posts a message every day for 30 days - self.reactor.advance(ONE_DAY_IN_SECONDS) - self.helper.send(room_id, "I'm still here", tok=access_token) - - # Notice that the user *still* does not contribute to R30! - r30_results = self.get_success( - self.hs.get_datastores().main.count_r30_users() - ) - self.assertEqual(r30_results, {"all": 0}) - - self.reactor.advance(ONE_DAY_IN_SECONDS) - self.helper.send(room_id, "Still here!", tok=access_token) - - # *Now* the user appears in R30. - r30_results = self.get_success(self.hs.get_datastores().main.count_r30_users()) - self.assertEqual(r30_results, {"all": 1, "unknown": 1}) - - class PhoneHomeR30V2TestCase(HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -175,7 +26,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): def make_homeserver( self, reactor: ThreadedMemoryReactorClock, clock: Clock ) -> HomeServer: - hs = super(PhoneHomeR30V2TestCase, self).make_homeserver(reactor, clock) + hs = super().make_homeserver(reactor, clock) # We don't want our tests to actually report statistics, so check # that it's not enabled @@ -363,11 +214,6 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): r30_results, {"all": 0, "android": 0, "electron": 0, "ios": 0, "web": 0} ) - # Check that this is a situation where old R30 differs: - # old R30 DOES count this as 'retained'. - r30_results = self.get_success(store.count_r30_users()) - self.assertEqual(r30_results, {"all": 1, "ios": 1}) - # Now we want to check that the user will still be able to appear in # R30v2 as long as the user performs some other activity between # 30 and 60 days later. diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 15fce165b6..366b6fd5f0 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -11,18 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Mapping, Sequence, Union +from typing import Any, List, Mapping, Optional, Sequence, Union from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor -from synapse.api.errors import HttpResponseException from synapse.appservice import ApplicationService from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock from tests import unittest +from tests.unittest import override_config PROTOCOL = "myproto" TOKEN = "myastoken" @@ -40,7 +40,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): hs_token=TOKEN, ) - def test_query_3pe_authenticates_token(self) -> None: + def test_query_3pe_authenticates_token_via_header(self) -> None: """ Tests that 3pe queries to the appservice are authenticated with the appservice's token. @@ -75,12 +75,18 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): args: Mapping[Any, Any], headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], ) -> List[JsonDict]: - # Ensure the access token is passed as both a header and query arg. - if not headers.get("Authorization") or not args.get(b"access_token"): + # Ensure the access token is passed as a header. + if not headers or not headers.get(b"Authorization"): raise RuntimeError("Access token not provided") + # ... and not as a query param + if b"access_token" in args: + raise RuntimeError( + "Access token should not be passed as a query param." + ) - self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"]) - self.assertEqual(args.get(b"access_token"), TOKEN) + self.assertEqual( + headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()] + ) self.request_url = url if url == URL_USER: return SUCCESS_RESULT_USER @@ -92,7 +98,7 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): ) # We assign to a method, which mypy doesn't like. - self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment] + self.api.get_json = Mock(side_effect=get_json) # type: ignore[method-assign] result = self.get_success( self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]}) @@ -107,10 +113,13 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): self.assertEqual(self.request_url, URL_LOCATION) self.assertEqual(result, SUCCESS_RESULT_LOCATION) - def test_fallback(self) -> None: + @override_config({"use_appservice_legacy_authorization": True}) + def test_query_3pe_authenticates_token_via_param(self) -> None: """ - Tests that the fallback to legacy URLs works. + Tests that 3pe queries to the appservice are authenticated + with the appservice's token. """ + SUCCESS_RESULT_USER = [ { "protocol": PROTOCOL, @@ -120,44 +129,63 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): }, } ] + SUCCESS_RESULT_LOCATION = [ + { + "protocol": PROTOCOL, + "alias": "#a:room", + "fields": { + "more": "fields", + }, + } + ] URL_USER = f"{URL}/_matrix/app/v1/thirdparty/user/{PROTOCOL}" - FALLBACK_URL_USER = f"{URL}/_matrix/app/unstable/thirdparty/user/{PROTOCOL}" + URL_LOCATION = f"{URL}/_matrix/app/v1/thirdparty/location/{PROTOCOL}" self.request_url = None - self.v1_seen = False async def get_json( url: str, args: Mapping[Any, Any], - headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], + headers: Optional[ + Mapping[Union[str, bytes], Sequence[Union[str, bytes]]] + ] = None, ) -> List[JsonDict]: - # Ensure the access token is passed as both a header and query arg. - if not headers.get("Authorization") or not args.get(b"access_token"): - raise RuntimeError("Access token not provided") + # Ensure the access token is passed as a both a query param and in the headers. + if not args.get(b"access_token"): + raise RuntimeError("Access token should be provided in query params.") + if not headers or not headers.get(b"Authorization"): + raise RuntimeError("Access token should be provided in auth headers.") - self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"]) self.assertEqual(args.get(b"access_token"), TOKEN) + self.assertEqual( + headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()] + ) self.request_url = url if url == URL_USER: - self.v1_seen = True - raise HttpResponseException(404, "NOT_FOUND", b"NOT_FOUND") - elif url == FALLBACK_URL_USER: return SUCCESS_RESULT_USER + elif url == URL_LOCATION: + return SUCCESS_RESULT_LOCATION else: raise RuntimeError( "URL provided was invalid. This should never be seen." ) # We assign to a method, which mypy doesn't like. - self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment] + self.api.get_json = Mock(side_effect=get_json) # type: ignore[method-assign] result = self.get_success( self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]}) ) - self.assertTrue(self.v1_seen) - self.assertEqual(self.request_url, FALLBACK_URL_USER) + self.assertEqual(self.request_url, URL_USER) self.assertEqual(result, SUCCESS_RESULT_USER) + result = self.get_success( + self.api.query_3pe( + self.service, "location", PROTOCOL, {b"some": [b"field"]} + ) + ) + self.assertEqual(self.request_url, URL_LOCATION) + self.assertEqual(result, SUCCESS_RESULT_LOCATION) def test_claim_keys(self) -> None: """ @@ -184,14 +212,16 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]], ) -> JsonDict: # Ensure the access token is passed as both a header and query arg. - if not headers.get("Authorization"): + if not headers.get(b"Authorization"): raise RuntimeError("Access token not provided") - self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"]) + self.assertEqual( + headers.get(b"Authorization"), [f"Bearer {TOKEN}".encode()] + ) return RESPONSE # We assign to a method, which mypy doesn't like. - self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[assignment] + self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[method-assign] MISSING_KEYS = [ # Known user, known device, missing algorithm. diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index dee976356f..6ac5fc1ae7 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. import re -from typing import Generator -from unittest.mock import Mock +from typing import Any, Generator +from unittest.mock import AsyncMock, Mock from twisted.internet import defer from synapse.appservice import ApplicationService, Namespace from tests import unittest -from tests.test_utils import simple_async_mock def _regex(regex: str, exclusive: bool = True) -> Namespace: @@ -43,21 +42,19 @@ class ApplicationServiceTestCase(unittest.TestCase): ) self.store = Mock() - self.store.get_aliases_for_room = simple_async_mock([]) - self.store.get_local_users_in_room = simple_async_mock([]) + self.store.get_aliases_for_room = AsyncMock(return_value=[]) + self.store.get_local_users_in_room = AsyncMock(return_value=[]) @defer.inlineCallbacks def test_regex_user_id_prefix_match( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @@ -65,15 +62,13 @@ class ApplicationServiceTestCase(unittest.TestCase): @defer.inlineCallbacks def test_regex_user_id_prefix_no_match( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.assertFalse( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @@ -81,17 +76,15 @@ class ApplicationServiceTestCase(unittest.TestCase): @defer.inlineCallbacks def test_regex_room_member_is_checked( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@someone_else:matrix.org" self.event.type = "m.room.member" self.event.state_key = "@irc_foobar:matrix.org" self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @@ -99,17 +92,15 @@ class ApplicationServiceTestCase(unittest.TestCase): @defer.inlineCallbacks def test_regex_room_id_match( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!some_prefixs0m3th1nGsome_suffix:matrix.org" self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @@ -117,38 +108,32 @@ class ApplicationServiceTestCase(unittest.TestCase): @defer.inlineCallbacks def test_regex_room_id_no_match( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_ROOMS].append( _regex("!some_prefix.*some_suffix:matrix.org") ) self.event.room_id = "!XqBunHwQIXUiqCaoxq:matrix.org" self.assertFalse( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @defer.inlineCallbacks - def test_regex_alias_match( - self, - ) -> Generator["defer.Deferred[object]", object, None]: + def test_regex_alias_match(self) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room = simple_async_mock( - ["#irc_foobar:matrix.org", "#athing:matrix.org"] + self.store.get_aliases_for_room = AsyncMock( + return_value=["#irc_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_local_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = AsyncMock(return_value=[]) self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @@ -192,14 +177,14 @@ class ApplicationServiceTestCase(unittest.TestCase): @defer.inlineCallbacks def test_regex_alias_no_match( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) - self.store.get_aliases_for_room = simple_async_mock( - ["#xmpp_foobar:matrix.org", "#athing:matrix.org"] + self.store.get_aliases_for_room = AsyncMock( + return_value=["#xmpp_foobar:matrix.org", "#athing:matrix.org"] ) - self.store.get_local_users_in_room = simple_async_mock([]) + self.store.get_local_users_in_room = AsyncMock(return_value=[]) self.assertFalse( ( yield defer.ensureDeferred( @@ -213,28 +198,26 @@ class ApplicationServiceTestCase(unittest.TestCase): @defer.inlineCallbacks def test_regex_multiple_matches( self, - ) -> Generator["defer.Deferred[object]", object, None]: + ) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_ALIASES].append( _regex("#irc_.*:matrix.org") ) self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) self.event.sender = "@irc_foobar:matrix.org" - self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"]) - self.store.get_local_users_in_room = simple_async_mock([]) + self.store.get_aliases_for_room = AsyncMock( + return_value=["#irc_barfoo:matrix.org"] + ) + self.store.get_local_users_in_room = AsyncMock(return_value=[]) self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @defer.inlineCallbacks - def test_interested_in_self( - self, - ) -> Generator["defer.Deferred[object]", object, None]: + def test_interested_in_self(self) -> Generator["defer.Deferred[Any]", object, None]: # make sure invites get through self.service.sender = "@appservice:name" self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) @@ -243,32 +226,26 @@ class ApplicationServiceTestCase(unittest.TestCase): self.event.state_key = self.service.sender self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) @defer.inlineCallbacks - def test_member_list_match( - self, - ) -> Generator["defer.Deferred[object]", object, None]: + def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]: self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*")) # Note that @irc_fo:here is the AS user. - self.store.get_local_users_in_room = simple_async_mock( - ["@alice:here", "@irc_fo:here", "@bob:here"] + self.store.get_local_users_in_room = AsyncMock( + return_value=["@alice:here", "@irc_fo:here", "@bob:here"] ) - self.store.get_aliases_for_room = simple_async_mock([]) + self.store.get_aliases_for_room = AsyncMock(return_value=[]) self.event.sender = "@xmpp_foobar:matrix.org" self.assertTrue( ( - yield defer.ensureDeferred( - self.service.is_interested_in_event( - self.event.event_id, self.event, self.store - ) + yield self.service.is_interested_in_event( + self.event.event_id, self.event, self.store ) ) ) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index e2a3bad065..445919417e 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Optional, Sequence, Tuple, cast -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from typing_extensions import TypeAlias @@ -37,7 +37,6 @@ from synapse.types import DeviceListUpdates, JsonDict from synapse.util import Clock from tests import unittest -from tests.test_utils import simple_async_mock from ..utils import MockClock @@ -62,10 +61,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): txn = Mock(id=txn_id, service=service, events=events) # mock methods - self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) - txn.send = simple_async_mock(True) - txn.complete = simple_async_mock(True) - self.store.create_appservice_txn = simple_async_mock(txn) + self.store.get_appservice_state = AsyncMock( + return_value=ApplicationServiceState.UP + ) + txn.send = AsyncMock(return_value=True) + txn.complete = AsyncMock(return_value=True) + self.store.create_appservice_txn = AsyncMock(return_value=txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -89,10 +90,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): events = [Mock(), Mock()] txn = Mock(id="idhere", service=service, events=events) - self.store.get_appservice_state = simple_async_mock( - ApplicationServiceState.DOWN + self.store.get_appservice_state = AsyncMock( + return_value=ApplicationServiceState.DOWN ) - self.store.create_appservice_txn = simple_async_mock(txn) + self.store.create_appservice_txn = AsyncMock(return_value=txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -118,10 +119,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase): txn = Mock(id=txn_id, service=service, events=events) # mock methods - self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP) - self.store.set_appservice_state = simple_async_mock(True) - txn.send = simple_async_mock(False) # fails to send - self.store.create_appservice_txn = simple_async_mock(txn) + self.store.get_appservice_state = AsyncMock( + return_value=ApplicationServiceState.UP + ) + self.store.set_appservice_state = AsyncMock(return_value=True) + txn.send = AsyncMock(return_value=False) # fails to send + self.store.create_appservice_txn = AsyncMock(return_value=txn) # actual call self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events))) @@ -150,7 +153,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.as_api = Mock() self.store = Mock() self.service = Mock() - self.callback = simple_async_mock() + self.callback = AsyncMock() self.recoverer = _Recoverer( clock=cast(Clock, self.clock), as_api=self.as_api, @@ -174,8 +177,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() # shouldn't have called anything prior to waiting for exp backoff self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = simple_async_mock(True) - txn.complete = simple_async_mock(None) + txn.send = AsyncMock(return_value=True) + txn.complete = AsyncMock(return_value=None) # wait for exp backoff self.clock.advance_time(2) self.assertEqual(1, txn.send.call_count) @@ -202,8 +205,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.recoverer.recover() self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count) - txn.send = simple_async_mock(False) - txn.complete = simple_async_mock(None) + txn.send = AsyncMock(return_value=False) + txn.complete = AsyncMock(return_value=None) self.clock.advance_time(2) self.assertEqual(1, txn.send.call_count) self.assertEqual(0, txn.complete.call_count) @@ -216,7 +219,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): self.assertEqual(3, txn.send.call_count) self.assertEqual(0, txn.complete.call_count) self.assertEqual(0, self.callback.call_count) - txn.send = simple_async_mock(True) # successfully send the txn + txn.send = AsyncMock(return_value=True) # successfully send the txn pop_txn = True # returns the txn the first time, then no more. self.clock.advance_time(16) self.assertEqual(1, txn.send.call_count) # new mock reset call count @@ -244,7 +247,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None: self.scheduler = ApplicationServiceScheduler(hs) self.txn_ctrl = Mock() - self.txn_ctrl.send = simple_async_mock() + self.txn_ctrl.send = AsyncMock() # Replace instantiated _TransactionController instances with our Mock self.scheduler.txn_ctrl = self.txn_ctrl diff --git a/tests/config/test_oauth_delegation.py b/tests/config/test_oauth_delegation.py new file mode 100644 index 0000000000..5c91031746 --- /dev/null +++ b/tests/config/test_oauth_delegation.py @@ -0,0 +1,278 @@ +# Copyright 2023 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest.mock import Mock + +from synapse.config import ConfigError +from synapse.config.homeserver import HomeServerConfig +from synapse.module_api import ModuleApi +from synapse.types import JsonDict + +from tests.server import get_clock, setup_test_homeserver +from tests.unittest import TestCase, skip_unless +from tests.utils import default_config + +try: + import authlib # noqa: F401 + + HAS_AUTHLIB = True +except ImportError: + HAS_AUTHLIB = False + + +# These are a few constants that are used as config parameters in the tests. +SERVER_NAME = "test" +ISSUER = "https://issuer/" +CLIENT_ID = "test-client-id" +CLIENT_SECRET = "test-client-secret" +BASE_URL = "https://synapse/" + + +class CustomAuthModule: + """A module which registers a password auth provider.""" + + @staticmethod + def parse_config(config: JsonDict) -> None: + pass + + def __init__(self, config: None, api: ModuleApi): + api.register_password_auth_provider_callbacks( + auth_checkers={("m.login.password", ("password",)): Mock()}, + ) + + +@skip_unless(HAS_AUTHLIB, "requires authlib") +class MSC3861OAuthDelegation(TestCase): + """Test that the Homeserver fails to initialize if the config is invalid.""" + + def setUp(self) -> None: + self.config_dict: JsonDict = { + **default_config("test"), + "public_baseurl": BASE_URL, + "enable_registration": False, + "experimental_features": { + "msc3861": { + "enabled": True, + "issuer": ISSUER, + "client_id": CLIENT_ID, + "client_auth_method": "client_secret_post", + "client_secret": CLIENT_SECRET, + } + }, + } + + def parse_config(self) -> HomeServerConfig: + config = HomeServerConfig() + config.parse_config_dict(self.config_dict, "", "") + return config + + def test_client_secret_post_works(self) -> None: + self.config_dict["experimental_features"]["msc3861"].update( + client_auth_method="client_secret_post", + client_secret=CLIENT_SECRET, + ) + + self.parse_config() + + def test_client_secret_post_requires_client_secret(self) -> None: + self.config_dict["experimental_features"]["msc3861"].update( + client_auth_method="client_secret_post", + client_secret=None, + ) + + with self.assertRaises(ConfigError): + self.parse_config() + + def test_client_secret_basic_works(self) -> None: + self.config_dict["experimental_features"]["msc3861"].update( + client_auth_method="client_secret_basic", + client_secret=CLIENT_SECRET, + ) + + self.parse_config() + + def test_client_secret_basic_requires_client_secret(self) -> None: + self.config_dict["experimental_features"]["msc3861"].update( + client_auth_method="client_secret_basic", + client_secret=None, + ) + + with self.assertRaises(ConfigError): + self.parse_config() + + def test_client_secret_jwt_works(self) -> None: + self.config_dict["experimental_features"]["msc3861"].update( + client_auth_method="client_secret_jwt", + client_secret=CLIENT_SECRET, + ) + + self.parse_config() + + def test_client_secret_jwt_requires_client_secret(self) -> None: + self.config_dict["experimental_features"]["msc3861"].update( + client_auth_method="client_secret_jwt", + client_secret=None, + ) + + with self.assertRaises(ConfigError): + self.parse_config() + + def test_invalid_client_auth_method(self) -> None: + self.config_dict["experimental_features"]["msc3861"].update( + client_auth_method="invalid", + ) + + with self.assertRaises(ConfigError): + self.parse_config() + + def test_private_key_jwt_requires_jwk(self) -> None: + self.config_dict["experimental_features"]["msc3861"].update( + client_auth_method="private_key_jwt", + ) + + with self.assertRaises(ConfigError): + self.parse_config() + + def test_private_key_jwt_works(self) -> None: + self.config_dict["experimental_features"]["msc3861"].update( + client_auth_method="private_key_jwt", + jwk={ + "p": "-frVdP_tZ-J_nIR6HNMDq1N7aunwm51nAqNnhqIyuA8ikx7LlQED1tt2LD3YEvYyW8nxE2V95HlCRZXQPMiRJBFOsbmYkzl2t-MpavTaObB_fct_JqcRtdXddg4-_ihdjRDwUOreq_dpWh6MIKsC3UyekfkHmeEJg5YpOTL15j8", + "kty": "RSA", + "q": "oFw-Enr_YozQB1ab-kawn4jY3yHi8B1nSmYT0s8oTCflrmps5BFJfCkHL5ij3iY15z0o2m0N-jjB1oSJ98O4RayEEYNQlHnTNTl0kRIWzpoqblHUIxVcahIpP_xTovBJzwi8XXoLGqHOOMA-r40LSyVgP2Ut8D9qBwV6_UfT0LU", + "d": "WFkDPYo4b4LIS64D_QtQfGGuAObPvc3HFfp9VZXyq3SJR58XZRHE0jqtlEMNHhOTgbMYS3w8nxPQ_qVzY-5hs4fIanwvB64mAoOGl0qMHO65DTD_WsGFwzYClJPBVniavkLE2Hmpu8IGe6lGliN8vREC6_4t69liY-XcN_ECboVtC2behKkLOEASOIMuS7YcKAhTJFJwkl1dqDlliEn5A4u4xy7nuWQz3juB1OFdKlwGA5dfhDNglhoLIwNnkLsUPPFO-WB5ZNEW35xxHOToxj4bShvDuanVA6mJPtTKjz0XibjB36bj_nF_j7EtbE2PdGJ2KevAVgElR4lqS4ISgQ", + "e": "AQAB", + "kid": "test", + "qi": "cPfNk8l8W5exVNNea4d7QZZ8Qr8LgHghypYAxz8PQh1fNa8Ya1SNUDVzC2iHHhszxxA0vB9C7jGze8dBrvnzWYF1XvQcqNIVVgHhD57R1Nm3dj2NoHIKe0Cu4bCUtP8xnZQUN4KX7y4IIcgRcBWG1hT6DEYZ4BxqicnBXXNXAUI", + "dp": "dKlMHvslV1sMBQaKWpNb3gPq0B13TZhqr3-E2_8sPlvJ3fD8P4CmwwnOn50JDuhY3h9jY5L06sBwXjspYISVv8hX-ndMLkEeF3lrJeA5S70D8rgakfZcPIkffm3tlf1Ok3v5OzoxSv3-67Df4osMniyYwDUBCB5Oq1tTx77xpU8", + "dq": "S4ooU1xNYYcjl9FcuJEEMqKsRrAXzzSKq6laPTwIp5dDwt2vXeAm1a4eDHXC-6rUSZGt5PbqVqzV4s-cjnJMI8YYkIdjNg4NSE1Ac_YpeDl3M3Colb5CQlU7yUB7xY2bt0NOOFp9UJZYJrOo09mFMGjy5eorsbitoZEbVqS3SuE", + "n": "nJbYKqFwnURKimaviyDFrNLD3gaKR1JW343Qem25VeZxoMq1665RHVoO8n1oBm4ClZdjIiZiVdpyqzD5-Ow12YQgQEf1ZHP3CCcOQQhU57Rh5XvScTe5IxYVkEW32IW2mp_CJ6WfjYpfeL4azarVk8H3Vr59d1rSrKTVVinVdZer9YLQyC_rWAQNtHafPBMrf6RYiNGV9EiYn72wFIXlLlBYQ9Fx7bfe1PaL6qrQSsZP3_rSpuvVdLh1lqGeCLR0pyclA9uo5m2tMyCXuuGQLbA_QJm5xEc7zd-WFdux2eXF045oxnSZ_kgQt-pdN7AxGWOVvwoTf9am6mSkEdv6iw", + }, + ) + self.parse_config() + + def test_registration_cannot_be_enabled(self) -> None: + self.config_dict["enable_registration"] = True + with self.assertRaises(ConfigError): + self.parse_config() + + def test_user_consent_cannot_be_enabled(self) -> None: + tmpdir = self.mktemp() + os.mkdir(tmpdir) + self.config_dict["user_consent"] = { + "require_at_registration": True, + "version": "1", + "template_dir": tmpdir, + "server_notice_content": { + "msgtype": "m.text", + "body": "foo", + }, + } + with self.assertRaises(ConfigError): + self.parse_config() + + def test_password_config_cannot_be_enabled(self) -> None: + self.config_dict["password_config"] = {"enabled": True} + with self.assertRaises(ConfigError): + self.parse_config() + + def test_oidc_sso_cannot_be_enabled(self) -> None: + self.config_dict["oidc_providers"] = [ + { + "idp_id": "microsoft", + "idp_name": "Microsoft", + "issuer": "https://login.microsoftonline.com/<tenant id>/v2.0", + "client_id": "<client id>", + "client_secret": "<client secret>", + "scopes": ["openid", "profile"], + "authorization_endpoint": "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/authorize", + "token_endpoint": "https://login.microsoftonline.com/<tenant id>/oauth2/v2.0/token", + "userinfo_endpoint": "https://graph.microsoft.com/oidc/userinfo", + } + ] + + with self.assertRaises(ConfigError): + self.parse_config() + + def test_cas_sso_cannot_be_enabled(self) -> None: + self.config_dict["cas_config"] = { + "enabled": True, + "server_url": "https://cas-server.com", + "displayname_attribute": "name", + "required_attributes": {"userGroup": "staff", "department": "None"}, + } + + with self.assertRaises(ConfigError): + self.parse_config() + + def test_auth_providers_cannot_be_enabled(self) -> None: + self.config_dict["modules"] = [ + { + "module": f"{__name__}.{CustomAuthModule.__qualname__}", + "config": {}, + } + ] + + # This requires actually setting up an HS, as the module will be run on setup, + # which should raise as the module tries to register an auth provider + config = self.parse_config() + reactor, clock = get_clock() + with self.assertRaises(ConfigError): + setup_test_homeserver( + self.addCleanup, reactor=reactor, clock=clock, config=config + ) + + def test_jwt_auth_cannot_be_enabled(self) -> None: + self.config_dict["jwt_config"] = { + "enabled": True, + "secret": "my-secret-token", + "algorithm": "HS256", + } + + with self.assertRaises(ConfigError): + self.parse_config() + + def test_login_via_existing_session_cannot_be_enabled(self) -> None: + self.config_dict["login_via_existing_session"] = {"enabled": True} + with self.assertRaises(ConfigError): + self.parse_config() + + def test_captcha_cannot_be_enabled(self) -> None: + self.config_dict.update( + enable_registration_captcha=True, + recaptcha_public_key="test", + recaptcha_private_key="test", + ) + with self.assertRaises(ConfigError): + self.parse_config() + + def test_refreshable_tokens_cannot_be_enabled(self) -> None: + self.config_dict.update( + refresh_token_lifetime="24h", + refreshable_access_token_lifetime="10m", + nonrefreshable_access_token_lifetime="24h", + ) + with self.assertRaises(ConfigError): + self.parse_config() + + def test_session_lifetime_cannot_be_set(self) -> None: + self.config_dict["session_lifetime"] = "24h" + with self.assertRaises(ConfigError): + self.parse_config() + + def test_enable_3pid_changes_cannot_be_enabled(self) -> None: + self.config_dict["enable_3pid_changes"] = True + with self.assertRaises(ConfigError): + self.parse_config() diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py index f12147eaa0..0c27dd21e2 100644 --- a/tests/config/test_ratelimiting.py +++ b/tests/config/test_ratelimiting.py @@ -12,11 +12,42 @@ # See the License for the specific language governing permissions and # limitations under the License. from synapse.config.homeserver import HomeServerConfig +from synapse.config.ratelimiting import RatelimitSettings from tests.unittest import TestCase from tests.utils import default_config +class ParseRatelimitSettingsTestcase(TestCase): + def test_depth_1(self) -> None: + cfg = { + "a": { + "per_second": 5, + "burst_count": 10, + } + } + parsed = RatelimitSettings.parse(cfg, "a") + self.assertEqual(parsed, RatelimitSettings("a", 5, 10)) + + def test_depth_2(self) -> None: + cfg = { + "a": { + "b": { + "per_second": 5, + "burst_count": 10, + }, + } + } + parsed = RatelimitSettings.parse(cfg, "a.b") + self.assertEqual(parsed, RatelimitSettings("a.b", 5, 10)) + + def test_missing(self) -> None: + parsed = RatelimitSettings.parse( + {}, "a", defaults={"per_second": 5, "burst_count": 10} + ) + self.assertEqual(parsed, RatelimitSettings("a", 5, 10)) + + class RatelimitConfigTestCase(TestCase): def test_parse_rc_federation(self) -> None: config_dict = default_config("test") diff --git a/tests/config/test_workers.py b/tests/config/test_workers.py index 49a6bdf408..2a643ae4f3 100644 --- a/tests/config/test_workers.py +++ b/tests/config/test_workers.py @@ -94,6 +94,7 @@ class WorkerDutyConfigTestCase(TestCase): # so that it doesn't raise an exception here. # (This is not read by `_should_this_worker_perform_duty`.) "notify_appservices": False, + "instance_map": {"main": {"host": "127.0.0.1", "port": 0}}, }, ) @@ -138,7 +139,9 @@ class WorkerDutyConfigTestCase(TestCase): """ main_process_config = self._make_worker_config( - worker_app="synapse.app.homeserver", worker_name=None + worker_app="synapse.app.homeserver", + worker_name=None, + extras={"instance_map": {"main": {"host": "127.0.0.1", "port": 0}}}, ) self.assertTrue( @@ -203,6 +206,7 @@ class WorkerDutyConfigTestCase(TestCase): # so that it doesn't raise an exception here. # (This is not read by `_should_this_worker_perform_duty`.) "notify_appservices": False, + "instance_map": {"main": {"host": "127.0.0.1", "port": 0}}, }, ) @@ -236,7 +240,9 @@ class WorkerDutyConfigTestCase(TestCase): Tests new config options. This is for the master's config. """ main_process_config = self._make_worker_config( - worker_app="synapse.app.homeserver", worker_name=None + worker_app="synapse.app.homeserver", + worker_name=None, + extras={"instance_map": {"main": {"host": "127.0.0.1", "port": 0}}}, ) self.assertTrue( @@ -262,7 +268,9 @@ class WorkerDutyConfigTestCase(TestCase): Tests new config options. This is for the worker's config. """ appservice_worker_config = self._make_worker_config( - worker_app="synapse.app.generic_worker", worker_name="worker1" + worker_app="synapse.app.generic_worker", + worker_name="worker1", + extras={"instance_map": {"main": {"host": "127.0.0.1", "port": 0}}}, ) self.assertTrue( @@ -298,6 +306,7 @@ class WorkerDutyConfigTestCase(TestCase): extras={ "notify_appservices_from_worker": "worker2", "update_user_directory_from_worker": "worker1", + "instance_map": {"main": {"host": "127.0.0.1", "port": 0}}, }, ) self.assertFalse(worker1_config.should_notify_appservices) @@ -309,6 +318,7 @@ class WorkerDutyConfigTestCase(TestCase): extras={ "notify_appservices_from_worker": "worker2", "update_user_directory_from_worker": "worker1", + "instance_map": {"main": {"host": "127.0.0.1", "port": 0}}, }, ) self.assertTrue(worker2_config.should_notify_appservices) diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 7c63b2ea4c..c5700771b0 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -45,7 +45,6 @@ from synapse.types import JsonDict from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable from tests.unittest import logcontext_clean, override_config @@ -190,23 +189,24 @@ class KeyringTestCase(unittest.HomeserverTestCase): kr = keyring.Keyring(self.hs) key1 = signedjson.key.generate_signing_key("1") - r = self.hs.get_datastores().main.store_server_keys_json( + r = self.hs.get_datastores().main.store_server_keys_response( "server9", - get_key_id(key1), from_server="test", - ts_now_ms=int(time.time() * 1000), - ts_expires_ms=1000, + ts_added_ms=int(time.time() * 1000), + verify_keys={ + get_key_id(key1): FetchKeyResult( + verify_key=get_verify_key(key1), valid_until_ts=1000 + ) + }, # The entire response gets signed & stored, just include the bits we # care about. - key_json_bytes=canonicaljson.encode_canonical_json( - { - "verify_keys": { - get_key_id(key1): { - "key": encode_verify_key_base64(get_verify_key(key1)) - } + response_json={ + "verify_keys": { + get_key_id(key1): { + "key": encode_verify_key_base64(get_verify_key(key1)) } } - ), + }, ) self.get_success(r) @@ -286,34 +286,6 @@ class KeyringTestCase(unittest.HomeserverTestCase): d = kr.verify_json_for_server(self.hs.hostname, json1, 0) self.get_success(d) - def test_verify_json_for_server_with_null_valid_until_ms(self) -> None: - """Tests that we correctly handle key requests for keys we've stored - with a null `ts_valid_until_ms` - """ - mock_fetcher = Mock() - mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) - - key1 = signedjson.key.generate_signing_key("1") - r = self.hs.get_datastores().main.store_server_signature_keys( - "server9", - int(time.time() * 1000), - # None is not a valid value in FetchKeyResult, but we're abusing this - # API to insert null values into the database. The nulls get converted - # to 0 when fetched in KeyStore.get_server_signature_keys. - {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)}, # type: ignore[arg-type] - ) - self.get_success(r) - - json1: JsonDict = {} - signedjson.sign.sign_json(json1, "server9", key1) - - # should succeed on a signed object with a 0 minimum_valid_until_ms - d = self.hs.get_datastores().main.get_server_signature_keys( - [("server9", get_key_id(key1))] - ) - result = self.get_success(d) - self.assertEquals(result[("server9", get_key_id(key1))].valid_until_ts, 0) - def test_verify_json_dedupes_key_requests(self) -> None: """Two requests for the same key should be deduped.""" key1 = signedjson.key.generate_signing_key("1") @@ -456,24 +428,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated - lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json_for_remote( - [lookup_triplet] + SERVER_NAME, [testverifykey_id] ) ) - res_keys = key_json[lookup_triplet] - self.assertEqual(len(res_keys), 1) - res = res_keys[0] - self.assertEqual(res["key_id"], testverifykey_id) - self.assertEqual(res["from_server"], SERVER_NAME) - self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) - self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) + res = key_json[testverifykey_id] + self.assertIsNotNone(res) + assert res is not None + self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) + self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) # we expect it to be encoded as canonical json *before* it hits the db - self.assertEqual( - bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) - ) + self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response)) # change the server name: the result should be ignored response["server_name"] = "OTHER_SERVER" @@ -576,23 +543,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated - lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json_for_remote( - [lookup_triplet] + SERVER_NAME, [testverifykey_id] ) ) - res_keys = key_json[lookup_triplet] - self.assertEqual(len(res_keys), 1) - res = res_keys[0] - self.assertEqual(res["key_id"], testverifykey_id) - self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) - self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) - self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) - - self.assertEqual( - bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) - ) + res = key_json[testverifykey_id] + self.assertIsNotNone(res) + assert res is not None + self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) + self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) + + self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response)) def test_get_multiple_keys_from_perspectives(self) -> None: """Check that we can correctly request multiple keys for the same server""" @@ -699,23 +661,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): self.assertEqual(k.verify_key.version, "ver1") # check that the perspectives store is correctly updated - lookup_triplet = (SERVER_NAME, testverifykey_id, None) key_json = self.get_success( self.hs.get_datastores().main.get_server_keys_json_for_remote( - [lookup_triplet] + SERVER_NAME, [testverifykey_id] ) ) - res_keys = key_json[lookup_triplet] - self.assertEqual(len(res_keys), 1) - res = res_keys[0] - self.assertEqual(res["key_id"], testverifykey_id) - self.assertEqual(res["from_server"], self.mock_perspective_server.server_name) - self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000) - self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS) - - self.assertEqual( - bytes(res["key_json"]), canonicaljson.encode_canonical_json(response) - ) + res = key_json[testverifykey_id] + self.assertIsNotNone(res) + assert res is not None + self.assertEqual(res.added_ts, self.reactor.seconds() * 1000) + self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS) + + self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response)) def test_invalid_perspectives_responses(self) -> None: """Check that invalid responses from the perspectives server are rejected""" diff --git a/tests/events/test_presence_router.py b/tests/events/test_presence_router.py index 6fb1f1bd6e..0fcfe38efa 100644 --- a/tests/events/test_presence_router.py +++ b/tests/events/test_presence_router.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Dict, Iterable, List, Optional, Set, Tuple, Union -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import attr @@ -30,7 +30,6 @@ from synapse.types import JsonDict, StreamToken, create_requester from synapse.util import Clock from tests.handlers.test_sync import generate_sync_config -from tests.test_utils import simple_async_mock from tests.unittest import ( FederatingHomeserverTestCase, HomeserverTestCase, @@ -157,7 +156,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. self.fed_transport_client = Mock(spec=["send_transaction"]) - self.fed_transport_client.send_transaction = simple_async_mock({}) + self.fed_transport_client.send_transaction = AsyncMock(return_value={}) hs = self.setup_test_homeserver( federation_transport_client=self.fed_transport_client, diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py index 6687c28e8f..b5e42f9600 100644 --- a/tests/events/test_snapshot.py +++ b/tests/events/test_snapshot.py @@ -101,8 +101,7 @@ class TestEventContext(unittest.HomeserverTestCase): self.assertEqual( context.state_group_before_event, d_context.state_group_before_event ) - self.assertEqual(context.prev_group, d_context.prev_group) - self.assertEqual(context.delta_ids, d_context.delta_ids) + self.assertEqual(context.state_group_deltas, d_context.state_group_deltas) self.assertEqual(context.app_service, d_context.app_service) self.assertEqual( diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py index e40eac2eb0..978612e432 100644 --- a/tests/events/test_utils.py +++ b/tests/events/test_utils.py @@ -16,6 +16,7 @@ import unittest as stdlib_unittest from typing import Any, List, Mapping, Optional import attr +from parameterized import parameterized from synapse.api.constants import EventContentFields from synapse.api.room_versions import RoomVersions @@ -23,6 +24,7 @@ from synapse.events import EventBase, make_event_from_dict from synapse.events.utils import ( PowerLevelsContent, SerializeEventConfig, + _split_field, copy_and_fixup_power_levels_contents, maybe_upsert_event_field, prune_event, @@ -138,18 +140,16 @@ class PruneEventTestCase(stdlib_unittest.TestCase): }, ) - # As of MSC2176 we now redact the membership and prev_states keys. + # As of room versions we now redact the membership, prev_states, and origin keys. self.run_test( - {"type": "A", "prev_state": "prev_state", "membership": "join"}, - {"type": "A", "content": {}, "signatures": {}, "unsigned": {}}, - room_version=RoomVersions.MSC2176, - ) - - # As of MSC3989 we now redact the origin key. - self.run_test( - {"type": "A", "origin": "example.com"}, + { + "type": "A", + "prev_state": "prev_state", + "membership": "join", + "origin": "example.com", + }, {"type": "A", "content": {}, "signatures": {}, "unsigned": {}}, - room_version=RoomVersions.MSC3989, + room_version=RoomVersions.V11, ) def test_unsigned(self) -> None: @@ -225,16 +225,21 @@ class PruneEventTestCase(stdlib_unittest.TestCase): }, ) - # After MSC2176, create events get nothing redacted. + # After MSC2176, create events should preserve field `content` self.run_test( - {"type": "m.room.create", "content": {"not_a_real_key": True}}, + { + "type": "m.room.create", + "content": {"not_a_real_key": True}, + "origin": "some_homeserver", + "nonsense_field": "some_random_garbage", + }, { "type": "m.room.create", "content": {"not_a_real_key": True}, "signatures": {}, "unsigned": {}, }, - room_version=RoomVersions.MSC2176, + room_version=RoomVersions.V11, ) def test_power_levels(self) -> None: @@ -284,7 +289,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): "signatures": {}, "unsigned": {}, }, - room_version=RoomVersions.MSC2176, + room_version=RoomVersions.V11, ) def test_alias_event(self) -> None: @@ -347,7 +352,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): "signatures": {}, "unsigned": {}, }, - room_version=RoomVersions.MSC2176, + room_version=RoomVersions.V11, ) def test_join_rules(self) -> None: @@ -470,7 +475,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): "signatures": {}, "unsigned": {}, }, - room_version=RoomVersions.MSC3821, + room_version=RoomVersions.V11, ) # Ensure this doesn't break if an invalid field is sent. @@ -489,7 +494,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): "signatures": {}, "unsigned": {}, }, - room_version=RoomVersions.MSC3821, + room_version=RoomVersions.V11, ) self.run_test( @@ -507,7 +512,7 @@ class PruneEventTestCase(stdlib_unittest.TestCase): "signatures": {}, "unsigned": {}, }, - room_version=RoomVersions.MSC3821, + room_version=RoomVersions.V11, ) def test_relations(self) -> None: @@ -794,3 +799,40 @@ class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase): def test_invalid_nesting_raises_type_error(self) -> None: with self.assertRaises(TypeError): copy_and_fixup_power_levels_contents({"a": {"b": {"c": 1}}}) # type: ignore[dict-item] + + +class SplitFieldTestCase(stdlib_unittest.TestCase): + @parameterized.expand( + [ + # A field with no dots. + ["m", ["m"]], + # Simple dotted fields. + ["m.foo", ["m", "foo"]], + ["m.foo.bar", ["m", "foo", "bar"]], + # Backslash is used as an escape character. + [r"m\.foo", ["m.foo"]], + [r"m\\.foo", ["m\\", "foo"]], + [r"m\\\.foo", [r"m\.foo"]], + [r"m\\\\.foo", ["m\\\\", "foo"]], + [r"m\foo", [r"m\foo"]], + [r"m\\foo", [r"m\foo"]], + [r"m\\\foo", [r"m\\foo"]], + [r"m\\\\foo", [r"m\\foo"]], + # Ensure that escapes at the end don't cause issues. + ["m.foo\\", ["m", "foo\\"]], + ["m.foo\\", ["m", "foo\\"]], + [r"m.foo\.", ["m", "foo."]], + [r"m.foo\\.", ["m", "foo\\", ""]], + [r"m.foo\\\.", ["m", r"foo\."]], + # Empty parts (corresponding to properties which are an empty string) are allowed. + [".m", ["", "m"]], + ["..m", ["", "", "m"]], + ["m.", ["m", ""]], + ["m..", ["m", "", ""]], + ["m..foo", ["m", "", "foo"]], + # Invalid escape sequences. + [r"\m", [r"\m"]], + ] + ) + def test_split_field(self, input: str, expected: str) -> None: + self.assertEqual(_split_field(input), expected) diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 129d7cfd93..73a2766baf 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest.mock import AsyncMock from synapse.api.errors import Codes, SynapseError from synapse.rest import admin @@ -20,7 +20,6 @@ from synapse.rest.client import login, room from synapse.types import JsonDict, UserID, create_requester from tests import unittest -from tests.test_utils import make_awaitable class RoomComplexityTests(unittest.FederatingHomeserverTestCase): @@ -58,7 +57,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): async def get_current_state_event_counts(room_id: str) -> int: return int(500 * 1.23) - store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment] + store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[method-assign] # Get the room complexity again -- make sure it's our artificial value channel = self.make_signed_federation_request( @@ -75,9 +74,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] - handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] - return_value=make_awaitable(("", 1)) + fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] + return_value=("", 1) ) d = handler._remote_join( @@ -106,9 +105,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] - handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] - return_value=make_awaitable(("", 1)) + fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] + return_value=("", 1) ) d = handler._remote_join( @@ -143,16 +142,16 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] - return_value=make_awaitable(("", 1)) + fed_transport.client.get_json = AsyncMock(return_value=None) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] + return_value=("", 1) ) # Artificially raise the complexity async def get_current_state_event_counts(room_id: str) -> int: return 600 - self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment] + self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[method-assign] d = handler._remote_join( create_requester(u1), @@ -200,9 +199,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] - handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] - return_value=make_awaitable(("", 1)) + fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] + return_value=("", 1) ) d = handler._remote_join( @@ -230,9 +229,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment] - handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment] - return_value=make_awaitable(("", 1)) + fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign] + handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign] + return_value=("", 1) ) d = handler._remote_join( diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py index 391ae51707..75ae740b43 100644 --- a/tests/federation/test_federation_catch_up.py +++ b/tests/federation/test_federation_catch_up.py @@ -1,6 +1,6 @@ from typing import Callable, Collection, List, Optional, Tuple from unittest import mock -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -19,7 +19,7 @@ from synapse.types import JsonDict from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination -from tests.test_utils import event_injection, make_awaitable +from tests.test_utils import event_injection from tests.unittest import FederatingHomeserverTestCase @@ -50,8 +50,8 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # This mock is crucial for destination_rooms to be populated. # TODO: this seems to no longer be the case---tests pass with this mock # commented out. - state_storage_controller.get_current_hosts_in_room = Mock( # type: ignore[assignment] - return_value=make_awaitable({"test", "host2"}) + state_storage_controller.get_current_hosts_in_room = AsyncMock( # type: ignore[method-assign] + return_value={"test", "host2"} ) # whenever send_transaction is called, record the pdu data @@ -431,28 +431,24 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase): # ACT: call _wake_destinations_needing_catchup # patch wake_destination to just count the destinations instead - woken = [] + woken = set() def wake_destination_track(destination: str) -> None: - woken.append(destination) + woken.add(destination) - self.federation_sender.wake_destination = wake_destination_track # type: ignore[assignment] + self.federation_sender.wake_destination = wake_destination_track # type: ignore[method-assign] - # cancel the pre-existing timer for _wake_destinations_needing_catchup - # this is because we are calling it manually rather than waiting for it - # to be called automatically - assert self.federation_sender._catchup_after_startup_timer is not None - self.federation_sender._catchup_after_startup_timer.cancel() - - self.get_success( - self.federation_sender._wake_destinations_needing_catchup(), by=5.0 - ) + # We wait quite long so that all dests can be woken up, since there is a delay + # between them. + self.pump(by=5.0) # ASSERT (_wake_destinations_needing_catchup): # - all remotes are woken up, save for zzzerver self.assertNotIn("zzzerver", woken) - # - all destinations are woken exactly once; they appear once in woken. - self.assertCountEqual(woken, server_names[:-1]) + # - all destinations are woken, potentially more than once, since the + # wake up is called regularly and we don't ack in this test that a transaction + # has been successfully sent. + self.assertCountEqual(woken, set(server_names[:-1])) def test_not_latest_event(self) -> None: """Test that we send the latest event in the room even if its not ours.""" diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index 91694e4fca..a45ab83683 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -124,7 +124,7 @@ class FederationClientTest(FederatingHomeserverTestCase): # check the right call got made to the agent self._mock_agent.request.assert_called_once_with( b"GET", - b"matrix://yet.another.server/_matrix/federation/v1/state/%21room_id?event_id=event_id", + b"matrix-federation://yet.another.server/_matrix/federation/v1/state/%21room_id?event_id=event_id", headers=mock.ANY, bodyProducer=None, ) @@ -232,7 +232,7 @@ class FederationClientTest(FederatingHomeserverTestCase): # check the right call got made to the agent self._mock_agent.request.assert_called_once_with( b"GET", - b"matrix://yet.another.server/_matrix/federation/v1/event/event_id", + b"matrix-federation://yet.another.server/_matrix/federation/v1/event/event_id", headers=mock.ANY, bodyProducer=None, ) diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 9e104fd96a..caf04b54cb 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Callable, FrozenSet, List, Optional, Set -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from signedjson import key, sign from signedjson.types import BaseKey, SigningKey @@ -29,7 +29,6 @@ from synapse.server import HomeServer from synapse.types import JsonDict, ReadReceipt from synapse.util import Clock -from tests.test_utils import make_awaitable from tests.unittest import HomeserverTestCase @@ -43,15 +42,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.federation_transport_client = Mock(spec=["send_transaction"]) + self.federation_transport_client.send_transaction = AsyncMock() hs = self.setup_test_homeserver( federation_transport_client=self.federation_transport_client, ) - hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] - return_value=make_awaitable({"test", "host2"}) + hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock( # type: ignore[method-assign] + return_value={"test", "host2"} ) - hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment] + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[method-assign] hs.get_storage_controllers().state.get_current_hosts_in_room ) @@ -64,7 +64,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def test_send_receipts(self) -> None: mock_send_transaction = self.federation_transport_client.send_transaction - mock_send_transaction.return_value = make_awaitable({}) + mock_send_transaction.return_value = {} sender = self.hs.get_federation_sender() receipt = ReadReceipt( @@ -75,7 +75,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): thread_id=None, data={"ts": 1234}, ) - self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) + self.get_success(sender.send_read_receipt(receipt)) self.pump() @@ -104,13 +104,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): def test_send_receipts_thread(self) -> None: mock_send_transaction = self.federation_transport_client.send_transaction - mock_send_transaction.return_value = make_awaitable({}) + mock_send_transaction.return_value = {} # Create receipts for: # # * The same room / user on multiple threads. # * A different user in the same room. sender = self.hs.get_federation_sender() + # Hack so that we have a txn in-flight so we batch up read receipts + # below + sender.wake_destination("host2") for user, thread in ( ("alice", None), ("alice", "thread"), @@ -125,9 +128,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): thread_id=thread, data={"ts": 1234}, ) - self.successResultOf( - defer.ensureDeferred(sender.send_read_receipt(receipt)) - ) + defer.ensureDeferred(sender.send_read_receipt(receipt)) self.pump() @@ -180,7 +181,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): """Send two receipts in quick succession; the second should be flushed, but only after 20ms""" mock_send_transaction = self.federation_transport_client.send_transaction - mock_send_transaction.return_value = make_awaitable({}) + mock_send_transaction.return_value = {} sender = self.hs.get_federation_sender() receipt = ReadReceipt( @@ -191,7 +192,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): thread_id=None, data={"ts": 1234}, ) - self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) + self.get_success(sender.send_read_receipt(receipt)) self.pump() @@ -276,6 +277,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.federation_transport_client = Mock( spec=["send_transaction", "query_user_devices"] ) + self.federation_transport_client.send_transaction = AsyncMock() + self.federation_transport_client.query_user_devices = AsyncMock() return self.setup_test_homeserver( federation_transport_client=self.federation_transport_client, ) @@ -317,13 +320,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.record_transaction ) - def record_transaction( + async def record_transaction( self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None - ) -> "defer.Deferred[JsonDict]": + ) -> JsonDict: assert json_cb is not None data = json_cb() self.edus.extend(data["edus"]) - return defer.succeed({}) + return {} def test_send_device_updates(self) -> None: """Basic case: each device update should result in an EDU""" @@ -340,7 +343,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): self.reactor.advance(1) # a second call should produce no new device EDUs - self.hs.get_federation_sender().send_device_messages("host2") + self.get_success( + self.hs.get_federation_sender().send_device_messages(["host2"]) + ) self.assertEqual(self.edus, []) # a second device @@ -354,15 +359,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # Send the server a device list EDU for the other user, this will cause # it to try and resync the device lists. - self.federation_transport_client.query_user_devices.return_value = ( - make_awaitable( - { - "stream_id": "1", - "user_id": "@user2:host2", - "devices": [{"device_id": "D1"}], - } - ) - ) + self.federation_transport_client.query_user_devices.return_value = { + "stream_id": "1", + "user_id": "@user2:host2", + "devices": [{"device_id": "D1"}], + } self.get_success( self.device_handler.device_list_updater.incoming_device_list_update( @@ -533,7 +534,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): recovery """ mock_send_txn = self.federation_transport_client.send_transaction - mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) + mock_send_txn.side_effect = AssertionError("fail") # create devices u1 = self.register_user("user", "pass") @@ -552,7 +553,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # recover the server mock_send_txn.side_effect = self.record_transaction - self.hs.get_federation_sender().send_device_messages("host2") + self.get_success( + self.hs.get_federation_sender().send_device_messages(["host2"]) + ) # We queue up device list updates to be sent over federation, so we # advance to clear the queue. @@ -578,7 +581,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): This case tests the behaviour when the server has never been reachable. """ mock_send_txn = self.federation_transport_client.send_transaction - mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) + mock_send_txn.side_effect = AssertionError("fail") # create devices u1 = self.register_user("user", "pass") @@ -603,7 +606,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # recover the server mock_send_txn.side_effect = self.record_transaction - self.hs.get_federation_sender().send_device_messages("host2") + self.get_success( + self.hs.get_federation_sender().send_device_messages(["host2"]) + ) # We queue up device list updates to be sent over federation, so we # advance to clear the queue. @@ -636,7 +641,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # now the server goes offline mock_send_txn = self.federation_transport_client.send_transaction - mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail")) + mock_send_txn.side_effect = AssertionError("fail") self.login("user", "pass", device_id="D2") self.login("user", "pass", device_id="D3") @@ -658,7 +663,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # recover the server mock_send_txn.side_effect = self.record_transaction - self.hs.get_federation_sender().send_device_messages("host2") + self.get_success( + self.hs.get_federation_sender().send_device_messages(["host2"]) + ) # We queue up device list updates to be sent over federation, so we # advance to clear the queue. diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 5c850d1843..1831a5b47a 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -22,10 +22,10 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.events import EventBase, make_event_from_dict -from synapse.federation.federation_server import server_matches_acl_event from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer +from synapse.storage.controllers.state import server_acl_evaluator_from_event from synapse.types import JsonDict from synapse.util import Clock @@ -67,37 +67,46 @@ class ServerACLsTestCase(unittest.TestCase): e = _create_acl_event({"allow": ["*"], "deny": ["evil.com"]}) logging.info("ACL event: %s", e.content) - self.assertFalse(server_matches_acl_event("evil.com", e)) - self.assertFalse(server_matches_acl_event("EVIL.COM", e)) + server_acl_evalutor = server_acl_evaluator_from_event(e) - self.assertTrue(server_matches_acl_event("evil.com.au", e)) - self.assertTrue(server_matches_acl_event("honestly.not.evil.com", e)) + self.assertFalse(server_acl_evalutor.server_matches_acl_event("evil.com")) + self.assertFalse(server_acl_evalutor.server_matches_acl_event("EVIL.COM")) + + self.assertTrue(server_acl_evalutor.server_matches_acl_event("evil.com.au")) + self.assertTrue( + server_acl_evalutor.server_matches_acl_event("honestly.not.evil.com") + ) def test_block_ip_literals(self) -> None: e = _create_acl_event({"allow_ip_literals": False, "allow": ["*"]}) logging.info("ACL event: %s", e.content) - self.assertFalse(server_matches_acl_event("1.2.3.4", e)) - self.assertTrue(server_matches_acl_event("1a.2.3.4", e)) - self.assertFalse(server_matches_acl_event("[1:2::]", e)) - self.assertTrue(server_matches_acl_event("1:2:3:4", e)) + server_acl_evalutor = server_acl_evaluator_from_event(e) + + self.assertFalse(server_acl_evalutor.server_matches_acl_event("1.2.3.4")) + self.assertTrue(server_acl_evalutor.server_matches_acl_event("1a.2.3.4")) + self.assertFalse(server_acl_evalutor.server_matches_acl_event("[1:2::]")) + self.assertTrue(server_acl_evalutor.server_matches_acl_event("1:2:3:4")) def test_wildcard_matching(self) -> None: e = _create_acl_event({"allow": ["good*.com"]}) + + server_acl_evalutor = server_acl_evaluator_from_event(e) + self.assertTrue( - server_matches_acl_event("good.com", e), + server_acl_evalutor.server_matches_acl_event("good.com"), "* matches 0 characters", ) self.assertTrue( - server_matches_acl_event("GOOD.COM", e), + server_acl_evalutor.server_matches_acl_event("GOOD.COM"), "pattern is case-insensitive", ) self.assertTrue( - server_matches_acl_event("good.aa.com", e), + server_acl_evalutor.server_matches_acl_event("good.aa.com"), "* matches several characters, including '.'", ) self.assertFalse( - server_matches_acl_event("ishgood.com", e), + server_acl_evalutor.server_matches_acl_event("ishgood.com"), "pattern does not allow prefixes", ) diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py index 70209ab090..b63ef3d4ed 100644 --- a/tests/federation/transport/test_knocking.py +++ b/tests/federation/transport/test_knocking.py @@ -218,7 +218,7 @@ class FederationKnockingTestCase( ) -> EventBase: return pdu - homeserver.get_federation_server()._check_sigs_and_hash = ( # type: ignore[assignment] + homeserver.get_federation_server()._check_sigs_and_hash = ( # type: ignore[method-assign] approve_all_signature_checking ) @@ -229,7 +229,7 @@ class FederationKnockingTestCase( ) -> None: pass - homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] + homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[method-assign] return super().prepare(reactor, clock, homeserver) @@ -308,7 +308,7 @@ class FederationKnockingTestCase( self.assertEqual(200, channel.code, channel.result) # Check that we got the stripped room state in return - room_state_events = channel.json_body["knock_state_events"] + room_state_events = channel.json_body["knock_room_state"] # Validate the stripped room state events self.check_knock_room_state_against_room_state( diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 9014e60577..867dbd6001 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Dict, Iterable, List, Optional -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from parameterized import parameterized @@ -31,12 +31,12 @@ from synapse.appservice import ( from synapse.handlers.appservice import ApplicationServicesHandler from synapse.rest.client import login, receipts, register, room, sendtodevice from synapse.server import HomeServer -from synapse.types import JsonDict, RoomStreamToken +from synapse.types import JsonDict, RoomStreamToken, StreamKeyType from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest -from tests.test_utils import event_injection, make_awaitable, simple_async_mock +from tests.test_utils import event_injection from tests.unittest import override_config from tests.utils import MockClock @@ -46,15 +46,13 @@ class AppServiceHandlerTestCase(unittest.TestCase): def setUp(self) -> None: self.mock_store = Mock() - self.mock_as_api = Mock() + self.mock_as_api = AsyncMock() self.mock_scheduler = Mock() hs = Mock() hs.get_datastores.return_value = Mock(main=self.mock_store) - self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None) - self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None) - self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable( - None - ) + self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None) + self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None) + self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None) hs.get_application_service_api.return_value = self.mock_as_api hs.get_application_service_scheduler.return_value = self.mock_scheduler hs.get_clock.return_value = MockClock() @@ -69,22 +67,26 @@ class AppServiceHandlerTestCase(unittest.TestCase): self._mkservice(is_interested_in_event=False), ] - self.mock_as_api.query_user.return_value = make_awaitable(True) + self.mock_as_api.query_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = make_awaitable([]) + self.mock_store.get_user_by_id = AsyncMock(return_value=[]) event = Mock( sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) - self.mock_store.get_all_new_event_ids_stream.side_effect = [ - make_awaitable((0, {})), - make_awaitable((1, {event.event_id: 0})), - ] - self.mock_store.get_events_as_list.side_effect = [ - make_awaitable([]), - make_awaitable([event]), - ] - self.handler.notify_interested_services(RoomStreamToken(None, 1)) + self.mock_store.get_all_new_event_ids_stream = AsyncMock( + side_effect=[ + (0, {}), + (1, {event.event_id: 0}), + ] + ) + self.mock_store.get_events_as_list = AsyncMock( + side_effect=[ + [], + [event], + ] + ) + self.handler.notify_interested_services(RoomStreamToken(stream=1)) self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( interested_service, events=[event] @@ -95,15 +97,17 @@ class AppServiceHandlerTestCase(unittest.TestCase): services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = make_awaitable(None) + self.mock_store.get_user_by_id = AsyncMock(return_value=None) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_all_new_event_ids_stream.side_effect = [ - make_awaitable((0, {event.event_id: 0})), - ] - self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])] - self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.mock_as_api.query_user.return_value = True + self.mock_store.get_all_new_event_ids_stream = AsyncMock( + side_effect=[ + (0, {event.event_id: 0}), + ] + ) + self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]]) + self.handler.notify_interested_services(RoomStreamToken(stream=0)) self.mock_as_api.query_user.assert_called_once_with(services[0], user_id) @@ -112,15 +116,17 @@ class AppServiceHandlerTestCase(unittest.TestCase): services = [self._mkservice(is_interested_in_event=True)] services[0].is_interested_in_user.return_value = True self.mock_store.get_app_services.return_value = services - self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id}) + self.mock_store.get_user_by_id = AsyncMock(return_value={"name": user_id}) event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar") - self.mock_as_api.query_user.return_value = make_awaitable(True) - self.mock_store.get_all_new_event_ids_stream.side_effect = [ - make_awaitable((0, [event], {event.event_id: 0})), - ] + self.mock_as_api.query_user.return_value = True + self.mock_store.get_all_new_event_ids_stream = AsyncMock( + side_effect=[ + (0, [event], {event.event_id: 0}), + ] + ) - self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.handler.notify_interested_services(RoomStreamToken(stream=0)) self.assertFalse( self.mock_as_api.query_user.called, @@ -141,10 +147,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): self._mkservice_alias(is_room_alias_in_namespace=False), ] - self.mock_as_api.query_alias.return_value = make_awaitable(True) + self.mock_as_api.query_alias = AsyncMock(return_value=True) self.mock_store.get_app_services.return_value = services - self.mock_store.get_association_from_room_alias.return_value = make_awaitable( - Mock(room_id=room_id, servers=servers) + self.mock_store.get_association_from_room_alias = AsyncMock( + return_value=Mock(room_id=room_id, servers=servers) ) result = self.successResultOf( @@ -177,7 +183,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): def test_get_3pe_protocols_protocol_no_response(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] - self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) + self.mock_as_api.get_3pe_protocol.return_value = None response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols()) ) @@ -189,9 +195,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): def test_get_3pe_protocols_select_one_protocol(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] - self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( - {"x-protocol-data": 42, "instances": []} - ) + self.mock_as_api.get_3pe_protocol.return_value = { + "x-protocol-data": 42, + "instances": [], + } response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) ) @@ -205,9 +212,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): def test_get_3pe_protocols_one_protocol(self) -> None: service = self._mkservice(False, ["my-protocol"]) self.mock_store.get_app_services.return_value = [service] - self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( - {"x-protocol-data": 42, "instances": []} - ) + self.mock_as_api.get_3pe_protocol.return_value = { + "x-protocol-data": 42, + "instances": [], + } response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols()) ) @@ -222,9 +230,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): service_one = self._mkservice(False, ["my-protocol"]) service_two = self._mkservice(False, ["other-protocol"]) self.mock_store.get_app_services.return_value = [service_one, service_two] - self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( - {"x-protocol-data": 42, "instances": []} - ) + self.mock_as_api.get_3pe_protocol.return_value = { + "x-protocol-data": 42, + "instances": [], + } response = self.successResultOf( defer.ensureDeferred(self.handler.get_3pe_protocols()) ) @@ -287,17 +296,15 @@ class AppServiceHandlerTestCase(unittest.TestCase): interested_service = self._mkservice(is_interested_in_event=True) services = [interested_service] self.mock_store.get_app_services.return_value = services - self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( - 579 - ) + self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=579) event = Mock(event_id="event_1") - self.event_source.sources.receipt.get_new_events_as.return_value = ( - make_awaitable(([event], None)) + self.event_source.sources.receipt.get_new_events_as = AsyncMock( + return_value=([event], None) ) self.handler.notify_interested_services_ephemeral( - "receipt_key", 580, ["@fakerecipient:example.com"] + StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] ) self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( interested_service, ephemeral=[event] @@ -317,17 +324,15 @@ class AppServiceHandlerTestCase(unittest.TestCase): services = [interested_service] self.mock_store.get_app_services.return_value = services - self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable( - 580 - ) + self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=580) event = Mock(event_id="event_1") - self.event_source.sources.receipt.get_new_events_as.return_value = ( - make_awaitable(([event], None)) + self.event_source.sources.receipt.get_new_events_as = AsyncMock( + return_value=([event], None) ) self.handler.notify_interested_services_ephemeral( - "receipt_key", 580, ["@fakerecipient:example.com"] + StreamKeyType.RECEIPT, 580, ["@fakerecipient:example.com"] ) # This method will be called, but with an empty list of events self.mock_scheduler.enqueue_for_appservice.assert_called_once_with( @@ -350,9 +355,7 @@ class AppServiceHandlerTestCase(unittest.TestCase): A mock representing the ApplicationService. """ service = Mock() - service.is_interested_in_event.return_value = make_awaitable( - is_interested_in_event - ) + service.is_interested_in_event = AsyncMock(return_value=is_interested_in_event) service.token = "mock_service_token" service.url = "mock_service_url" service.protocols = protocols @@ -396,12 +399,12 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self.hs = hs # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track any outgoing ephemeral events - self.send_mock = simple_async_mock() - hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] + self.send_mock = AsyncMock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[method-assign] # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] + self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[method-assign] return_value=self._services ) @@ -419,6 +422,18 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): "exclusive_as_user", "password", self.exclusive_as_user_device_id ) + self.exclusive_as_user_2_device_id = "exclusive_as_device_2" + self.exclusive_as_user_2 = self.register_user("exclusive_as_user_2", "password") + self.exclusive_as_user_2_token = self.login( + "exclusive_as_user_2", "password", self.exclusive_as_user_2_device_id + ) + + self.exclusive_as_user_3_device_id = "exclusive_as_device_3" + self.exclusive_as_user_3 = self.register_user("exclusive_as_user_3", "password") + self.exclusive_as_user_3_token = self.login( + "exclusive_as_user_3", "password", self.exclusive_as_user_3_device_id + ) + def _notify_interested_services(self) -> None: # This is normally set in `notify_interested_services` but we need to call the # internal async version so the reactor gets pushed to completion. @@ -426,7 +441,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self.get_success( self.hs.get_application_service_handler()._notify_interested_services( RoomStreamToken( - None, self.hs.get_application_service_handler().current_max + stream=self.hs.get_application_service_handler().current_max ) ) ) @@ -619,7 +634,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): self.get_success( self.hs.get_application_service_handler()._notify_interested_services_ephemeral( services=[interested_appservice], - stream_key="receipt_key", + stream_key=StreamKeyType.RECEIPT, new_token=stream_token, users=[self.exclusive_as_user], ) @@ -846,6 +861,119 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): for count in service_id_to_message_count.values(): self.assertEqual(count, number_of_messages) + @unittest.override_config( + {"experimental_features": {"msc2409_to_device_messages_enabled": True}} + ) + def test_application_services_receive_local_to_device_for_many_users(self) -> None: + """ + Test that when a user sends a to-device message to many users + in an application service's user namespace, the + application service will receive all of them. + """ + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@exclusive_as_user:.+", + "exclusive": True, + }, + { + "regex": "@exclusive_as_user_2:.+", + "exclusive": True, + }, + { + "regex": "@exclusive_as_user_3:.+", + "exclusive": True, + }, + ], + }, + ) + + # Have local_user send a to-device message to exclusive_as_users + message_content = {"some_key": "some really interesting value"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/3", + content={ + "messages": { + self.exclusive_as_user: { + self.exclusive_as_user_device_id: message_content + }, + self.exclusive_as_user_2: { + self.exclusive_as_user_2_device_id: message_content + }, + self.exclusive_as_user_3: { + self.exclusive_as_user_3_device_id: message_content + }, + } + }, + access_token=self.local_user_token, + ) + self.assertEqual(chan.code, 200, chan.result) + + # Have exclusive_as_user send a to-device message to local_user + for user_token in [ + self.exclusive_as_user_token, + self.exclusive_as_user_2_token, + self.exclusive_as_user_3_token, + ]: + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.room_key_request/4", + content={ + "messages": { + self.local_user: {self.local_user_device_id: message_content} + } + }, + access_token=user_token, + ) + self.assertEqual(chan.code, 200, chan.result) + + # Check if our application service - that is interested in exclusive_as_user - received + # the to-device message as part of an AS transaction. + # Only the local_user -> exclusive_as_user to-device message should have been forwarded to the AS. + # + # The uninterested application service should not have been notified at all. + self.send_mock.assert_called_once() + ( + service, + _events, + _ephemeral, + to_device_messages, + _otks, + _fbks, + _device_list_summary, + ) = self.send_mock.call_args[0] + + # Assert that this was the same to-device message that local_user sent + self.assertEqual(service, interested_appservice) + + # Assert expected number of messages + self.assertEqual(len(to_device_messages), 3) + + for device_msg in to_device_messages: + self.assertEqual(device_msg["type"], "m.room_key_request") + self.assertEqual(device_msg["sender"], self.local_user) + self.assertEqual(device_msg["content"], message_content) + + self.assertEqual(to_device_messages[0]["to_user_id"], self.exclusive_as_user) + self.assertEqual( + to_device_messages[0]["to_device_id"], + self.exclusive_as_user_device_id, + ) + + self.assertEqual(to_device_messages[1]["to_user_id"], self.exclusive_as_user_2) + self.assertEqual( + to_device_messages[1]["to_device_id"], + self.exclusive_as_user_2_device_id, + ) + + self.assertEqual(to_device_messages[2]["to_user_id"], self.exclusive_as_user_3) + self.assertEqual( + to_device_messages[2]["to_device_id"], + self.exclusive_as_user_3_device_id, + ) + def _register_application_service( self, namespaces: Optional[Dict[str, Iterable[Dict]]] = None, @@ -894,12 +1022,12 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that # will be sent over the wire - self.put_json = simple_async_mock() - hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment] + self.put_json = AsyncMock() + hs.get_application_service_api().put_json = self.put_json # type: ignore[method-assign] # Mock out application services, and allow defining our own in tests self._services: List[ApplicationService] = [] - self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment] + self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[method-assign] return_value=self._services ) @@ -1000,8 +1128,8 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: # Mock the ApplicationServiceScheduler's _TransactionController's send method so that # we can track what's going out - self.send_mock = simple_async_mock() - hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method. + self.send_mock = AsyncMock() + hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[method-assign] # We assign to a method. # Define an application service for the tests self._service_token = "VERYSECRET" diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 036dbbc45b..413ff8795b 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional -from unittest.mock import Mock +from unittest.mock import AsyncMock import pymacaroons @@ -25,7 +25,6 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable class AuthTestCase(unittest.HomeserverTestCase): @@ -166,8 +165,8 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_mau_limits_exceeded_large(self) -> None: self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastores().main.get_monthly_active_count = Mock( - return_value=make_awaitable(self.large_number_of_users) + self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( + return_value=self.large_number_of_users ) self.get_failure( @@ -177,8 +176,8 @@ class AuthTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) - self.hs.get_datastores().main.get_monthly_active_count = Mock( - return_value=make_awaitable(self.large_number_of_users) + self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( + return_value=self.large_number_of_users ) token = self.get_success( self.auth_handler.create_login_token_for_user_id(self.user1) @@ -191,8 +190,8 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._limit_usage_by_mau = True # Set the server to be at the edge of too many users. - self.hs.get_datastores().main.get_monthly_active_count = Mock( - return_value=make_awaitable(self.auth_blocking._max_mau_value) + self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( + return_value=self.auth_blocking._max_mau_value ) # If not in monthly active cohort @@ -208,8 +207,8 @@ class AuthTestCase(unittest.HomeserverTestCase): self.assertIsNone(self.token_login(token)) # If in monthly active cohort - self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( - return_value=make_awaitable(self.clock.time_msec()) + self.hs.get_datastores().main.user_last_seen_monthly_active = AsyncMock( + return_value=self.clock.time_msec() ) self.get_success( self.auth_handler.create_access_token_for_user_id( @@ -224,8 +223,8 @@ class AuthTestCase(unittest.HomeserverTestCase): def test_mau_limits_not_exceeded(self) -> None: self.auth_blocking._limit_usage_by_mau = True - self.hs.get_datastores().main.get_monthly_active_count = Mock( - return_value=make_awaitable(self.small_number_of_users) + self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( + return_value=self.small_number_of_users ) # Ensure does not raise exception self.get_success( @@ -234,8 +233,8 @@ class AuthTestCase(unittest.HomeserverTestCase): ) ) - self.hs.get_datastores().main.get_monthly_active_count = Mock( - return_value=make_awaitable(self.small_number_of_users) + self.hs.get_datastores().main.get_monthly_active_count = AsyncMock( + return_value=self.small_number_of_users ) token = self.get_success( self.auth_handler.create_login_token_for_user_id(self.user1) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index 63aad0d10c..13e2cd153a 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -20,7 +20,6 @@ from synapse.handlers.cas import CasResponse from synapse.server import HomeServer from synapse.util import Clock -from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config # These are a few constants that are used as config parameters in the tests. @@ -61,7 +60,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] cas_response = CasResponse("test_user", {}) request = _mock_request() @@ -89,7 +88,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # Map a user via SSO. cas_response = CasResponse("test_user", {}) @@ -129,7 +128,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] cas_response = CasResponse("föö", {}) request = _mock_request() @@ -160,7 +159,7 @@ class CasHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # The response doesn't have the proper userGroup or department. cas_response = CasResponse("test_user", {}) @@ -198,6 +197,23 @@ class CasHandlerTestCase(HomeserverTestCase): auth_provider_session_id=None, ) + @override_config({"cas_config": {"enable_registration": False}}) + def test_map_cas_user_does_not_register_new_user(self) -> None: + """Ensures new users are not registered if the enabled registration flag is disabled.""" + + # stub out the auth handler + auth_handler = self.hs.get_auth_handler() + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] + + cas_response = CasResponse("test_user", {}) + request = _mock_request() + self.get_success( + self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") + ) + + # check that the auth handler was not called as expected + auth_handler.complete_sso_login.assert_not_called() + def _mock_request() -> Mock: """Returns a mock which will stand in as a SynapseRequest""" diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index ee48f9e546..d4ed068357 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -17,19 +17,22 @@ from typing import Optional from unittest import mock +from twisted.internet.defer import ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import NotFoundError, SynapseError from synapse.appservice import ApplicationService from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler +from synapse.rest import admin +from synapse.rest.client import devices, login, register from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex -from synapse.types import JsonDict +from synapse.types import JsonDict, create_requester from synapse.util import Clock +from synapse.util.task_scheduler import TaskScheduler from tests import unittest -from tests.test_utils import make_awaitable from tests.unittest import override_config user1 = "@boris:aaa" @@ -38,16 +41,16 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.appservice_api = mock.Mock() + self.appservice_api = mock.AsyncMock() hs = self.setup_test_homeserver( "server", - federation_http_client=None, application_service_api=self.appservice_api, ) handler = hs.get_device_handler() assert isinstance(handler, DeviceHandler) self.handler = handler self.store = hs.get_datastores().main + self.device_message_handler = hs.get_device_message_handler() return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -121,50 +124,50 @@ class DeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(3, len(res)) device_map = {d["device_id"]: d for d in res} - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user1, "device_id": "xyz", "display_name": "display 0", "last_seen_ip": None, "last_seen_ts": None, - }, - device_map["xyz"], + }.items(), + device_map["xyz"].items(), ) - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user1, "device_id": "fco", "display_name": "display 1", "last_seen_ip": "ip1", "last_seen_ts": 1000000, - }, - device_map["fco"], + }.items(), + device_map["fco"].items(), ) - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user1, "device_id": "abc", "display_name": "display 2", "last_seen_ip": "ip3", "last_seen_ts": 3000000, - }, - device_map["abc"], + }.items(), + device_map["abc"].items(), ) def test_get_device(self) -> None: self._record_users() res = self.get_success(self.handler.get_device(user1, "abc")) - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user1, "device_id": "abc", "display_name": "display 2", "last_seen_ip": "ip3", "last_seen_ts": 3000000, - }, - res, + }.items(), + res.items(), ) def test_delete_device(self) -> None: @@ -210,6 +213,51 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) self.assertIsNone(res) + def test_delete_device_and_big_device_inbox(self) -> None: + """Check that deleting a big device inbox is staged and batched asynchronously.""" + DEVICE_ID = "abc" + sender = "@sender:" + self.hs.hostname + receiver = "@receiver:" + self.hs.hostname + self._record_user(sender, DEVICE_ID, DEVICE_ID) + self._record_user(receiver, DEVICE_ID, DEVICE_ID) + + # queue a bunch of messages in the inbox + requester = create_requester(sender, device_id=DEVICE_ID) + for i in range(DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT + 10): + self.get_success( + self.device_message_handler.send_device_message( + requester, "message_type", {receiver: {"*": {"val": i}}} + ) + ) + + # delete the device + self.get_success(self.handler.delete_devices(receiver, [DEVICE_ID])) + + # messages should be deleted up to DEVICE_MSGS_DELETE_BATCH_LIMIT straight away + res = self.get_success( + self.store.db_pool.simple_select_list( + table="device_inbox", + keyvalues={"user_id": receiver}, + retcols=("user_id", "device_id", "stream_id"), + desc="get_device_id_from_device_inbox", + ) + ) + self.assertEqual(10, len(res)) + + # wait for the task scheduler to do a second delete pass + self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000) + + # remaining messages should now be deleted + res = self.get_success( + self.store.db_pool.simple_select_list( + table="device_inbox", + keyvalues={"user_id": receiver}, + retcols=("user_id", "device_id", "stream_id"), + desc="get_device_id_from_device_inbox", + ) + ) + self.assertEqual(0, len(res)) + def test_update_device(self) -> None: self._record_users() @@ -373,13 +421,11 @@ class DeviceTestCase(unittest.HomeserverTestCase): ) # Setup a response. - self.appservice_api.query_keys.return_value = make_awaitable( - { - "device_keys": { - local_user: {device_2: device_key_2b, device_3: device_key_3} - } + self.appservice_api.query_keys.return_value = { + "device_keys": { + local_user: {device_2: device_key_2b, device_3: device_key_3} } - ) + } # Request all devices. res = self.get_success( @@ -400,13 +446,22 @@ class DeviceTestCase(unittest.HomeserverTestCase): class DehydrationTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets_for_client_rest_resource, + login.register_servlets, + register.register_servlets, + devices.register_servlets, + ] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - hs = self.setup_test_homeserver("server", federation_http_client=None) + hs = self.setup_test_homeserver("server") handler = hs.get_device_handler() assert isinstance(handler, DeviceHandler) self.handler = handler + self.message_handler = hs.get_device_message_handler() self.registration = hs.get_registration_handler() self.auth = hs.get_auth() + self.auth_handler = hs.get_auth_handler() self.store = hs.get_datastores().main return hs @@ -419,6 +474,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase): stored_dehydrated_device_id = self.get_success( self.handler.store_dehydrated_device( user_id=user_id, + device_id=None, device_data={"device_data": {"foo": "bar"}}, initial_device_display_name="dehydrated device", ) @@ -432,11 +488,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) # Create a new login for the user and dehydrated the device - device_id, access_token, _expiration_time, _refresh_token = self.get_success( + device_id, access_token, _expiration_time, refresh_token = self.get_success( self.registration.register_device( user_id=user_id, device_id=None, initial_display_name="new device", + should_issue_refresh_token=True, ) ) @@ -467,6 +524,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase): self.assertEqual(user_info.device_id, retrieved_device_id) + # make sure the user device has the refresh token + assert refresh_token is not None + self.get_success( + self.auth_handler.refresh_token(refresh_token, 5 * 60 * 1000, 5 * 60 * 1000) + ) + # make sure the device has the display name that was set from the login res = self.get_success(self.handler.get_device(user_id, retrieved_device_id)) @@ -482,3 +545,89 @@ class DehydrationTestCase(unittest.HomeserverTestCase): ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id)) self.assertIsNone(ret) + + @unittest.override_config( + {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}} + ) + def test_dehydrate_v2_and_fetch_events(self) -> None: + user_id = "@boris:server" + + self.get_success(self.store.register_user(user_id, "foobar")) + + # First check if we can store and fetch a dehydrated device + stored_dehydrated_device_id = self.get_success( + self.handler.store_dehydrated_device( + user_id=user_id, + device_id=None, + device_data={"device_data": {"foo": "bar"}}, + initial_device_display_name="dehydrated device", + ) + ) + + device_info = self.get_success( + self.handler.get_dehydrated_device(user_id=user_id) + ) + assert device_info is not None + retrieved_device_id, device_data = device_info + self.assertEqual(retrieved_device_id, stored_dehydrated_device_id) + self.assertEqual(device_data, {"device_data": {"foo": "bar"}}) + + # Create a new login for the user + device_id, access_token, _expiration_time, _refresh_token = self.get_success( + self.registration.register_device( + user_id=user_id, + device_id=None, + initial_display_name="new device", + ) + ) + + requester = create_requester(user_id, device_id=device_id) + + # Fetching messages for a non-existing device should return an error + self.get_failure( + self.message_handler.get_events_for_dehydrated_device( + requester=requester, + device_id="not the right device ID", + since_token=None, + limit=10, + ), + SynapseError, + ) + + # Send a message to the dehydrated device + ensureDeferred( + self.message_handler.send_device_message( + requester=requester, + message_type="test.message", + messages={user_id: {stored_dehydrated_device_id: {"body": "foo"}}}, + ) + ) + self.pump() + + # Fetch the message of the dehydrated device + res = self.get_success( + self.message_handler.get_events_for_dehydrated_device( + requester=requester, + device_id=stored_dehydrated_device_id, + since_token=None, + limit=10, + ) + ) + + self.assertTrue(len(res["next_batch"]) > 1) + self.assertEqual(len(res["events"]), 1) + self.assertEqual(res["events"][0]["content"]["body"], "foo") + + # Fetch the message of the dehydrated device again, which should return + # the same message as it has not been deleted + res = self.get_success( + self.message_handler.get_events_for_dehydrated_device( + requester=requester, + device_id=stored_dehydrated_device_id, + since_token=None, + limit=10, + ) + ) + self.assertTrue(len(res["next_batch"]) > 1) + self.assertEqual(len(res["events"]), 1) + self.assertEqual(res["events"][0]["content"]["body"], "foo") diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 90aec484c4..367d94eca3 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Awaitable, Callable, Dict -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -27,14 +27,13 @@ from synapse.types import JsonDict, RoomAlias, create_requester from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable class DirectoryTestCase(unittest.HomeserverTestCase): """Tests the directory service.""" def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.mock_federation = Mock() + self.mock_federation = AsyncMock() self.mock_registry = Mock() self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} @@ -73,9 +72,10 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self) -> None: - self.mock_federation.make_query.return_value = make_awaitable( - {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} - ) + self.mock_federation.make_query.return_value = { + "room_id": "!8765qwer:test", + "servers": ["test", "remote"], + } result = self.get_success(self.handler.get_association(self.remote_room)) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 72d0584061..c5556f2844 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -13,7 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable +from typing import Dict, Iterable from unittest import mock from parameterized import parameterized @@ -27,17 +27,16 @@ from synapse.appservice import ApplicationService from synapse.handlers.device import DeviceHandler from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex -from synapse.types import JsonDict +from synapse.types import JsonDict, UserID from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable from tests.unittest import override_config class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.appservice_api = mock.Mock() + self.appservice_api = mock.AsyncMock() return self.setup_test_homeserver( federation_client=mock.Mock(), application_service_api=self.appservice_api ) @@ -45,6 +44,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_keys_handler() self.store = self.hs.get_datastores().main + self.requester = UserID.from_string(f"@test_requester:{self.hs.hostname}") def test_query_local_devices_no_devices(self) -> None: """If the user has no devices, we expect an empty list.""" @@ -161,6 +161,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): res2 = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -206,6 +207,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -225,6 +227,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -274,6 +277,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -286,6 +290,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -307,6 +312,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -348,6 +354,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=True, ) @@ -370,6 +377,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=True, ) @@ -792,29 +800,27 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" - self.hs.get_federation_client().query_client_keys = mock.Mock( # type: ignore[assignment] - return_value=make_awaitable( - { - "device_keys": {remote_user_id: {}}, - "master_keys": { - remote_user_id: { - "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - }, - "self_signing_keys": { - remote_user_id: { - "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" - + remote_self_signing_key: remote_self_signing_key - }, - } + self.hs.get_federation_client().query_client_keys = mock.AsyncMock( # type: ignore[method-assign] + return_value={ + "device_keys": {remote_user_id: {}}, + "master_keys": { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, }, - } - ) + }, + "self_signing_keys": { + remote_user_id: { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + + remote_self_signing_key: remote_self_signing_key + }, + } + }, + } ) e2e_handler = self.hs.get_e2e_keys_handler() @@ -865,34 +871,29 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): # Pretend we're sharing a room with the user we're querying. If not, # `_query_devices_for_destination` will return early. - self.store.get_rooms_for_user = mock.Mock( - return_value=make_awaitable({"some_room_id"}) - ) + self.store.get_rooms_for_user = mock.AsyncMock(return_value={"some_room_id"}) remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" - self.hs.get_federation_client().query_user_devices = mock.Mock( # type: ignore[assignment] - return_value=make_awaitable( - { + self.hs.get_federation_client().query_user_devices = mock.AsyncMock( # type: ignore[method-assign] + return_value={ + "user_id": remote_user_id, + "stream_id": 1, + "devices": [], + "master_key": { "user_id": remote_user_id, - "stream_id": 1, - "devices": [], - "master_key": { - "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - "self_signing_key": { - "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" - + remote_self_signing_key: remote_self_signing_key - }, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + }, + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + remote_self_signing_key: remote_self_signing_key }, - } - ) + }, + } ) e2e_handler = self.hs.get_e2e_keys_handler() @@ -978,20 +979,20 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): mock_get_rooms = mock.patch.object( self.store, "get_rooms_for_user", - new_callable=mock.MagicMock, - return_value=make_awaitable(["some_room_id"]), + new_callable=mock.AsyncMock, + return_value=["some_room_id"], ) mock_get_users = mock.patch.object( self.store, "get_users_server_still_shares_room_with", - new_callable=mock.MagicMock, - return_value=make_awaitable({remote_user_id}), + new_callable=mock.AsyncMock, + return_value={remote_user_id}, ) mock_request = mock.patch.object( self.hs.get_federation_client(), "query_user_devices", - new_callable=mock.MagicMock, - return_value=make_awaitable(response_body), + new_callable=mock.AsyncMock, + return_value=response_body, ) with mock_get_rooms, mock_get_users, mock_request as mocked_federation_request: @@ -1051,8 +1052,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) # Setup a response, but only for device 2. - self.appservice_api.claim_client_keys.return_value = make_awaitable( - ({local_user: {device_id_2: otk}}, [(local_user, device_id_1, "alg1", 1)]) + self.appservice_api.claim_client_keys.return_value = ( + {local_user: {device_id_2: otk}}, + [(local_user, device_id_1, "alg1", 1)], ) # we shouldn't have any unused fallback keys yet @@ -1080,6 +1082,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id_1: {"alg1": 1}, device_id_2: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=False, ) @@ -1117,14 +1120,16 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) # Setup a response. - self.appservice_api.claim_client_keys.return_value = make_awaitable( - ({local_user: {device_id_1: {**as_otk, **as_fallback_key}}}, []) - ) + response: Dict[str, Dict[str, Dict[str, JsonDict]]] = { + local_user: {device_id_1: {**as_otk, **as_fallback_key}} + } + self.appservice_api.claim_client_keys.return_value = (response, []) # Claim OTKs, which will ask the appservice and do nothing else. claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id_1: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=True, ) @@ -1160,8 +1165,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): self.assertEqual(fallback_res, ["alg1"]) # The appservice will return only the OTK. - self.appservice_api.claim_client_keys.return_value = make_awaitable( - ({local_user: {device_id_1: as_otk}}, []) + self.appservice_api.claim_client_keys.return_value = ( + {local_user: {device_id_1: as_otk}}, + [], ) # Claim OTKs, which should return the OTK from the appservice and the @@ -1169,6 +1175,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id_1: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=True, ) @@ -1202,6 +1209,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id_1: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=True, ) @@ -1221,14 +1229,16 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): self.assertEqual(fallback_res, ["alg1"]) # Finally, return only the fallback key from the appservice. - self.appservice_api.claim_client_keys.return_value = make_awaitable( - ({local_user: {device_id_1: as_fallback_key}}, []) + self.appservice_api.claim_client_keys.return_value = ( + {local_user: {device_id_1: as_fallback_key}}, + [], ) # Claim OTKs, which will return only the fallback key from the database. claim_res = self.get_success( self.handler.claim_one_time_keys( {local_user: {device_id_1: {"alg1": 1}}}, + self.requester, timeout=None, always_include_fallback_keys=True, ) @@ -1336,13 +1346,11 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) # Setup a response. - self.appservice_api.query_keys.return_value = make_awaitable( - { - "device_keys": { - local_user: {device_2: device_key_2b, device_3: device_key_3} - } + self.appservice_api.query_keys.return_value = { + "device_keys": { + local_user: {device_2: device_key_2b, device_3: device_key_3} } - ) + } # Request all devices. res = self.get_success(self.handler.query_local_devices({local_user: None})) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index bf0862ed54..4fc0742413 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -14,7 +14,7 @@ import logging from typing import Collection, Optional, cast from unittest import TestCase -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor @@ -40,7 +40,7 @@ from synapse.util import Clock from synapse.util.stringutils import random_string from tests import unittest -from tests.test_utils import event_injection, make_awaitable +from tests.test_utils import event_injection logger = logging.getLogger(__name__) @@ -57,7 +57,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - hs = self.setup_test_homeserver(federation_http_client=None) + hs = self.setup_test_homeserver() self.handler = hs.get_federation_handler() self.store = hs.get_datastores().main return hs @@ -262,7 +262,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): if (ev.type, ev.state_key) in {("m.room.create", ""), ("m.room.member", remote_server_user_id)} ] - for _ in range(0, 8): + for _ in range(8): event = make_event_from_dict( self.add_hashes_and_signatures_from_other_server( { @@ -370,15 +370,15 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): # We mock out the FederationClient.backfill method, to pretend that a remote # server has returned our fake event. - federation_client_backfill_mock = Mock(return_value=make_awaitable([event])) - self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[assignment] + federation_client_backfill_mock = AsyncMock(return_value=[event]) + self.hs.get_federation_client().backfill = federation_client_backfill_mock # type: ignore[method-assign] # We also mock the persist method with a side effect of itself. This allows us # to track when it has been called while preserving its function. persist_events_and_notify_mock = Mock( side_effect=self.hs.get_federation_event_handler().persist_events_and_notify ) - self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[assignment] + self.hs.get_federation_event_handler().persist_events_and_notify = ( # type: ignore[method-assign] persist_events_and_notify_mock ) @@ -631,33 +631,29 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase): }, RoomVersions.V10, ) - mock_make_membership_event = Mock( - return_value=make_awaitable( - ( - "example.com", - membership_event, - RoomVersions.V10, - ) + mock_make_membership_event = AsyncMock( + return_value=( + "example.com", + membership_event, + RoomVersions.V10, ) ) - mock_send_join = Mock( - return_value=make_awaitable( - SendJoinResult( - membership_event, - "example.com", - state=[ - EVENT_CREATE, - EVENT_CREATOR_MEMBERSHIP, - EVENT_INVITATION_MEMBERSHIP, - ], - auth_chain=[ - EVENT_CREATE, - EVENT_CREATOR_MEMBERSHIP, - EVENT_INVITATION_MEMBERSHIP, - ], - partial_state=True, - servers_in_room={"example.com"}, - ) + mock_send_join = AsyncMock( + return_value=SendJoinResult( + membership_event, + "example.com", + state=[ + EVENT_CREATE, + EVENT_CREATOR_MEMBERSHIP, + EVENT_INVITATION_MEMBERSHIP, + ], + auth_chain=[ + EVENT_CREATE, + EVENT_CREATOR_MEMBERSHIP, + EVENT_INVITATION_MEMBERSHIP, + ], + partial_state=True, + servers_in_room={"example.com"}, ) ) diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index c067e5bfe3..70e6a7e142 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -35,7 +35,7 @@ from synapse.types import JsonDict from synapse.util import Clock from tests import unittest -from tests.test_utils import event_injection, make_awaitable +from tests.test_utils import event_injection class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): @@ -50,6 +50,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): self.mock_federation_transport_client = mock.Mock( spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"] ) + self.mock_federation_transport_client.get_room_state_ids = mock.AsyncMock() + self.mock_federation_transport_client.get_room_state = mock.AsyncMock() + self.mock_federation_transport_client.get_event = mock.AsyncMock() + self.mock_federation_transport_client.backfill = mock.AsyncMock() return super().setup_test_homeserver( federation_transport_client=self.mock_federation_transport_client ) @@ -198,20 +202,14 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) # we expect an outbound request to /state_ids, so stub that out - self.mock_federation_transport_client.get_room_state_ids.return_value = ( - make_awaitable( - { - "pdu_ids": [e.event_id for e in state_at_prev_event], - "auth_chain_ids": [], - } - ) - ) + self.mock_federation_transport_client.get_room_state_ids.return_value = { + "pdu_ids": [e.event_id for e in state_at_prev_event], + "auth_chain_ids": [], + } # we also expect an outbound request to /state self.mock_federation_transport_client.get_room_state.return_value = ( - make_awaitable( - StateRequestResponse(auth_events=[], state=state_at_prev_event) - ) + StateRequestResponse(auth_events=[], state=state_at_prev_event) ) # we have to bump the clock a bit, to keep the retry logic in @@ -273,26 +271,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): room_version = self.get_success(main_store.get_room_version(room_id)) # We expect an outbound request to /state_ids, so stub that out - self.mock_federation_transport_client.get_room_state_ids.return_value = make_awaitable( - { - # Mimic the other server not knowing about the state at all. - # We want to cause Synapse to throw an error (`Unable to get - # missing prev_event $fake_prev_event`) and fail to backfill - # the pulled event. - "pdu_ids": [], - "auth_chain_ids": [], - } - ) + self.mock_federation_transport_client.get_room_state_ids.return_value = { + # Mimic the other server not knowing about the state at all. + # We want to cause Synapse to throw an error (`Unable to get + # missing prev_event $fake_prev_event`) and fail to backfill + # the pulled event. + "pdu_ids": [], + "auth_chain_ids": [], + } + # We also expect an outbound request to /state - self.mock_federation_transport_client.get_room_state.return_value = make_awaitable( - StateRequestResponse( - # Mimic the other server not knowing about the state at all. - # We want to cause Synapse to throw an error (`Unable to get - # missing prev_event $fake_prev_event`) and fail to backfill - # the pulled event. - auth_events=[], - state=[], - ) + self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse( + # Mimic the other server not knowing about the state at all. + # We want to cause Synapse to throw an error (`Unable to get + # missing prev_event $fake_prev_event`) and fail to backfill + # the pulled event. + auth_events=[], + state=[], ) pulled_event = make_event_from_dict( @@ -545,25 +540,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): ) # We expect an outbound request to /backfill, so stub that out - self.mock_federation_transport_client.backfill.return_value = make_awaitable( - { - "origin": self.OTHER_SERVER_NAME, - "origin_server_ts": 123, - "pdus": [ - # This is one of the important aspects of this test: we include - # `pulled_event_without_signatures` so it fails the signature check - # when we filter down the backfill response down to events which - # have valid signatures in - # `_check_sigs_and_hash_for_pulled_events_and_fetch` - pulled_event_without_signatures.get_pdu_json(), - # Then later when we process this valid signature event, when we - # fetch the missing `prev_event`s, we want to make sure that we - # backoff and don't try and fetch `pulled_event_without_signatures` - # again since we know it just had an invalid signature. - pulled_event.get_pdu_json(), - ], - } - ) + self.mock_federation_transport_client.backfill.return_value = { + "origin": self.OTHER_SERVER_NAME, + "origin_server_ts": 123, + "pdus": [ + # This is one of the important aspects of this test: we include + # `pulled_event_without_signatures` so it fails the signature check + # when we filter down the backfill response down to events which + # have valid signatures in + # `_check_sigs_and_hash_for_pulled_events_and_fetch` + pulled_event_without_signatures.get_pdu_json(), + # Then later when we process this valid signature event, when we + # fetch the missing `prev_event`s, we want to make sure that we + # backoff and don't try and fetch `pulled_event_without_signatures` + # again since we know it just had an invalid signature. + pulled_event.get_pdu_json(), + ], + } # Keep track of the count and make sure we don't make any of these requests event_endpoint_requested_count = 0 @@ -664,6 +657,99 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): StoreError, ) + def test_backfill_process_previously_failed_pull_attempt_event_in_the_background( + self, + ) -> None: + """ + Sanity check that events are still processed even if it is in the background + for events that already have failed pull attempts. + """ + OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}" + main_store = self.hs.get_datastores().main + + # Create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(main_store.get_room_version(room_id)) + + # Allow the remote user to send state events + self.helper.send_state( + room_id, + "m.room.power_levels", + {"events_default": 0, "state_default": 0}, + tok=tok, + ) + + # Add the remote user to the room + member_event = self.get_success( + event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") + ) + + initial_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) + + auth_event_ids = [ + initial_state_map[("m.room.create", "")], + initial_state_map[("m.room.power_levels", "")], + member_event.event_id, + ] + + # Create a regular event that should process + pulled_event = make_event_from_dict( + self.add_hashes_and_signatures_from_other_server( + { + "type": "test_regular_type", + "room_id": room_id, + "sender": OTHER_USER, + "prev_events": [ + member_event.event_id, + ], + "auth_events": auth_event_ids, + "origin_server_ts": 1, + "depth": 12, + "content": {"body": "pulled_event"}, + } + ), + room_version, + ) + + # Record a failed pull attempt for this event which will cause us to backfill it + # in the background from here on out. + self.get_success( + main_store.record_event_failed_pull_attempt( + room_id, pulled_event.event_id, "fake cause" + ) + ) + + # We expect an outbound request to /backfill, so stub that out + self.mock_federation_transport_client.backfill.return_value = { + "origin": self.OTHER_SERVER_NAME, + "origin_server_ts": 123, + "pdus": [ + pulled_event.get_pdu_json(), + ], + } + + # The function under test: try to backfill and process the pulled event + with LoggingContext("test"): + self.get_success( + self.hs.get_federation_event_handler().backfill( + self.OTHER_SERVER_NAME, + room_id, + limit=1, + extremities=["$some_extremity"], + ) + ) + + # Ensure `run_as_background_process(...)` has a chance to run (essentially + # `wait_for_background_processes()`) + self.reactor.pump((0.1,)) + + # Make sure we processed and persisted the pulled event + self.get_success(main_store.get_event(pulled_event.event_id, allow_none=False)) + def test_process_pulled_event_with_rejected_missing_state(self) -> None: """Ensure that we correctly handle pulled events with missing state containing a rejected state event diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py index 9691d66b48..1c5897c84e 100644 --- a/tests/handlers/test_message.py +++ b/tests/handlers/test_message.py @@ -46,18 +46,11 @@ class EventCreationTestCase(unittest.HomeserverTestCase): self._persist_event_storage_controller = persistence self.user_id = self.register_user("tester", "foobar") - self.access_token = self.login("tester", "foobar") - self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token) - - info = self.get_success( - self.hs.get_datastores().main.get_user_by_access_token( - self.access_token, - ) - ) - assert info is not None - self.token_id = info.token_id + device_id = "dev-1" + access_token = self.login("tester", "foobar", device_id=device_id) + self.room_id = self.helper.create_room_as(self.user_id, tok=access_token) - self.requester = create_requester(self.user_id, access_token_id=self.token_id) + self.requester = create_requester(self.user_id, device_id=device_id) def _create_and_persist_member_event(self) -> Tuple[EventBase, EventContext]: # Create a member event we can use as an auth_event diff --git a/tests/handlers/test_oauth_delegation.py b/tests/handlers/test_oauth_delegation.py new file mode 100644 index 0000000000..a72ecfdc97 --- /dev/null +++ b/tests/handlers/test_oauth_delegation.py @@ -0,0 +1,687 @@ +# Copyright 2022 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from http import HTTPStatus +from typing import Any, Dict, Union +from unittest.mock import ANY, AsyncMock, Mock +from urllib.parse import parse_qs + +from signedjson.key import ( + encode_verify_key_base64, + generate_signing_key, + get_verify_key, +) +from signedjson.sign import sign_json + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.errors import ( + AuthError, + Codes, + InvalidClientTokenError, + OAuthInsufficientScopeError, + SynapseError, +) +from synapse.rest import admin +from synapse.rest.client import account, devices, keys, login, logout, register +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests.test_utils import FakeResponse, get_awaitable_result +from tests.unittest import HomeserverTestCase, skip_unless +from tests.utils import mock_getRawHeaders + +try: + import authlib # noqa: F401 + + HAS_AUTHLIB = True +except ImportError: + HAS_AUTHLIB = False + + +# These are a few constants that are used as config parameters in the tests. +SERVER_NAME = "test" +ISSUER = "https://issuer/" +CLIENT_ID = "test-client-id" +CLIENT_SECRET = "test-client-secret" +BASE_URL = "https://synapse/" +SCOPES = ["openid"] + +AUTHORIZATION_ENDPOINT = ISSUER + "authorize" +TOKEN_ENDPOINT = ISSUER + "token" +USERINFO_ENDPOINT = ISSUER + "userinfo" +WELL_KNOWN = ISSUER + ".well-known/openid-configuration" +JWKS_URI = ISSUER + ".well-known/jwks.json" +INTROSPECTION_ENDPOINT = ISSUER + "introspect" + +SYNAPSE_ADMIN_SCOPE = "urn:synapse:admin:*" +MATRIX_USER_SCOPE = "urn:matrix:org.matrix.msc2967.client:api:*" +MATRIX_GUEST_SCOPE = "urn:matrix:org.matrix.msc2967.client:api:guest" +MATRIX_DEVICE_SCOPE_PREFIX = "urn:matrix:org.matrix.msc2967.client:device:" +DEVICE = "AABBCCDD" +MATRIX_DEVICE_SCOPE = MATRIX_DEVICE_SCOPE_PREFIX + DEVICE +SUBJECT = "abc-def-ghi" +USERNAME = "test-user" +USER_ID = "@" + USERNAME + ":" + SERVER_NAME + + +async def get_json(url: str) -> JsonDict: + # Mock get_json calls to handle jwks & oidc discovery endpoints + if url == WELL_KNOWN: + # Minimal discovery document, as defined in OpenID.Discovery + # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata + return { + "issuer": ISSUER, + "authorization_endpoint": AUTHORIZATION_ENDPOINT, + "token_endpoint": TOKEN_ENDPOINT, + "jwks_uri": JWKS_URI, + "userinfo_endpoint": USERINFO_ENDPOINT, + "introspection_endpoint": INTROSPECTION_ENDPOINT, + "response_types_supported": ["code"], + "subject_types_supported": ["public"], + "id_token_signing_alg_values_supported": ["RS256"], + } + elif url == JWKS_URI: + return {"keys": []} + + return {} + + +@skip_unless(HAS_AUTHLIB, "requires authlib") +class MSC3861OAuthDelegation(HomeserverTestCase): + servlets = [ + account.register_servlets, + devices.register_servlets, + keys.register_servlets, + register.register_servlets, + login.register_servlets, + logout.register_servlets, + admin.register_servlets, + ] + + def default_config(self) -> Dict[str, Any]: + config = super().default_config() + config["public_baseurl"] = BASE_URL + config["disable_registration"] = True + config["experimental_features"] = { + "msc3861": { + "enabled": True, + "issuer": ISSUER, + "client_id": CLIENT_ID, + "client_auth_method": "client_secret_post", + "client_secret": CLIENT_SECRET, + "admin_token": "admin_token_value", + } + } + return config + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.http_client = Mock(spec=["get_json"]) + self.http_client.get_json.side_effect = get_json + self.http_client.user_agent = b"Synapse Test" + + hs = self.setup_test_homeserver(proxied_http_client=self.http_client) + + self.auth = hs.get_auth() + + return hs + + def _assertParams(self) -> None: + """Assert that the request parameters are correct.""" + params = parse_qs(self.http_client.request.call_args[1]["data"].decode("utf-8")) + self.assertEqual(params["token"], ["mockAccessToken"]) + self.assertEqual(params["client_id"], [CLIENT_ID]) + self.assertEqual(params["client_secret"], [CLIENT_SECRET]) + + def test_inactive_token(self) -> None: + """The handler should return a 403 where the token is inactive.""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={"active": False}, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.http_client.request.assert_called_once_with( + method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY + ) + self._assertParams() + + def test_active_no_scope(self) -> None: + """The handler should return a 403 where no scope is given.""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={"active": True}, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.http_client.request.assert_called_once_with( + method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY + ) + self._assertParams() + + def test_active_user_no_subject(self) -> None: + """The handler should return a 500 when no subject is present.""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])}, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.http_client.request.assert_called_once_with( + method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY + ) + self._assertParams() + + def test_active_no_user_scope(self) -> None: + """The handler should return a 500 when no subject is present.""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join([MATRIX_DEVICE_SCOPE]), + }, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.http_client.request.assert_called_once_with( + method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY + ) + self._assertParams() + + def test_active_admin_not_user(self) -> None: + """The handler should raise when the scope has admin right but not user.""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join([SYNAPSE_ADMIN_SCOPE]), + "username": USERNAME, + }, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + self.get_failure(self.auth.get_user_by_req(request), InvalidClientTokenError) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.http_client.request.assert_called_once_with( + method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY + ) + self._assertParams() + + def test_active_admin(self) -> None: + """The handler should return a requester with admin rights.""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join([SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE]), + "username": USERNAME, + }, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = self.get_success(self.auth.get_user_by_req(request)) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.http_client.request.assert_called_once_with( + method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY + ) + self._assertParams() + self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME)) + self.assertEqual(requester.is_guest, False) + self.assertEqual(requester.device_id, None) + self.assertEqual( + get_awaitable_result(self.auth.is_server_admin(requester)), True + ) + + def test_active_admin_highest_privilege(self) -> None: + """The handler should resolve to the most permissive scope.""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join( + [SYNAPSE_ADMIN_SCOPE, MATRIX_USER_SCOPE, MATRIX_GUEST_SCOPE] + ), + "username": USERNAME, + }, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = self.get_success(self.auth.get_user_by_req(request)) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.http_client.request.assert_called_once_with( + method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY + ) + self._assertParams() + self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME)) + self.assertEqual(requester.is_guest, False) + self.assertEqual(requester.device_id, None) + self.assertEqual( + get_awaitable_result(self.auth.is_server_admin(requester)), True + ) + + def test_active_user(self) -> None: + """The handler should return a requester with normal user rights.""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join([MATRIX_USER_SCOPE]), + "username": USERNAME, + }, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = self.get_success(self.auth.get_user_by_req(request)) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.http_client.request.assert_called_once_with( + method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY + ) + self._assertParams() + self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME)) + self.assertEqual(requester.is_guest, False) + self.assertEqual(requester.device_id, None) + self.assertEqual( + get_awaitable_result(self.auth.is_server_admin(requester)), False + ) + + def test_active_user_with_device(self) -> None: + """The handler should return a requester with normal user rights and a device ID.""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]), + "username": USERNAME, + }, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = self.get_success(self.auth.get_user_by_req(request)) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.http_client.request.assert_called_once_with( + method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY + ) + self._assertParams() + self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME)) + self.assertEqual(requester.is_guest, False) + self.assertEqual( + get_awaitable_result(self.auth.is_server_admin(requester)), False + ) + self.assertEqual(requester.device_id, DEVICE) + + def test_multiple_devices(self) -> None: + """The handler should raise an error if multiple devices are found in the scope.""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join( + [ + MATRIX_USER_SCOPE, + f"{MATRIX_DEVICE_SCOPE_PREFIX}AABBCC", + f"{MATRIX_DEVICE_SCOPE_PREFIX}DDEEFF", + ] + ), + "username": USERNAME, + }, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + self.get_failure(self.auth.get_user_by_req(request), AuthError) + + def test_active_guest_not_allowed(self) -> None: + """The handler should return an insufficient scope error.""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]), + "username": USERNAME, + }, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + error = self.get_failure( + self.auth.get_user_by_req(request), OAuthInsufficientScopeError + ) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.http_client.request.assert_called_once_with( + method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY + ) + self._assertParams() + self.assertEqual( + getattr(error.value, "headers", {})["WWW-Authenticate"], + 'Bearer error="insufficient_scope", scope="urn:matrix:org.matrix.msc2967.client:api:*"', + ) + + def test_active_guest_allowed(self) -> None: + """The handler should return a requester with guest user rights and a device ID.""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join([MATRIX_GUEST_SCOPE, MATRIX_DEVICE_SCOPE]), + "username": USERNAME, + }, + ) + ) + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = self.get_success( + self.auth.get_user_by_req(request, allow_guest=True) + ) + self.http_client.get_json.assert_called_once_with(WELL_KNOWN) + self.http_client.request.assert_called_once_with( + method="POST", uri=INTROSPECTION_ENDPOINT, data=ANY, headers=ANY + ) + self._assertParams() + self.assertEqual(requester.user.to_string(), "@%s:%s" % (USERNAME, SERVER_NAME)) + self.assertEqual(requester.is_guest, True) + self.assertEqual( + get_awaitable_result(self.auth.is_server_admin(requester)), False + ) + self.assertEqual(requester.device_id, DEVICE) + + def test_unavailable_introspection_endpoint(self) -> None: + """The handler should return an internal server error.""" + request = Mock(args={}) + request.args[b"access_token"] = [b"mockAccessToken"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + + # The introspection endpoint is returning an error. + self.http_client.request = AsyncMock( + return_value=FakeResponse(code=500, body=b"Internal Server Error") + ) + error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) + self.assertEqual(error.value.code, 503) + + # The introspection endpoint request fails. + self.http_client.request = AsyncMock(side_effect=Exception()) + error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) + self.assertEqual(error.value.code, 503) + + # The introspection endpoint does not return a JSON object. + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, payload=["this is an array", "not an object"] + ) + ) + error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) + self.assertEqual(error.value.code, 503) + + # The introspection endpoint does not return valid JSON. + self.http_client.request = AsyncMock( + return_value=FakeResponse(code=200, body=b"this is not valid JSON") + ) + error = self.get_failure(self.auth.get_user_by_req(request), SynapseError) + self.assertEqual(error.value.code, 503) + + def make_device_keys(self, user_id: str, device_id: str) -> JsonDict: + # We only generate a master key to simplify the test. + master_signing_key = generate_signing_key(device_id) + master_verify_key = encode_verify_key_base64(get_verify_key(master_signing_key)) + + return { + "master_key": sign_json( + { + "user_id": user_id, + "usage": ["master"], + "keys": {"ed25519:" + master_verify_key: master_verify_key}, + }, + user_id, + master_signing_key, + ), + } + + def test_cross_signing(self) -> None: + """Try uploading device keys with OAuth delegation enabled.""" + + self.http_client.request = AsyncMock( + return_value=FakeResponse.json( + code=200, + payload={ + "active": True, + "sub": SUBJECT, + "scope": " ".join([MATRIX_USER_SCOPE, MATRIX_DEVICE_SCOPE]), + "username": USERNAME, + }, + ) + ) + keys_upload_body = self.make_device_keys(USER_ID, DEVICE) + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + keys_upload_body, + access_token="mockAccessToken", + ) + + self.assertEqual(channel.code, 200, channel.json_body) + + channel = self.make_request( + "POST", + "/_matrix/client/v3/keys/device_signing/upload", + keys_upload_body, + access_token="mockAccessToken", + ) + + self.assertEqual(channel.code, HTTPStatus.NOT_IMPLEMENTED, channel.json_body) + + def expect_unauthorized( + self, method: str, path: str, content: Union[bytes, str, JsonDict] = "" + ) -> None: + channel = self.make_request(method, path, content, shorthand=False) + + self.assertEqual(channel.code, 401, channel.json_body) + + def expect_unrecognized( + self, method: str, path: str, content: Union[bytes, str, JsonDict] = "" + ) -> None: + channel = self.make_request(method, path, content) + + self.assertEqual(channel.code, 404, channel.json_body) + self.assertEqual( + channel.json_body["errcode"], Codes.UNRECOGNIZED, channel.json_body + ) + + def test_uia_endpoints(self) -> None: + """Test that endpoints that were removed in MSC2964 are no longer available.""" + + # This is just an endpoint that should remain visible (but requires auth): + self.expect_unauthorized("GET", "/_matrix/client/v3/devices") + + # This remains usable, but will require a uia scope: + self.expect_unauthorized( + "POST", "/_matrix/client/v3/keys/device_signing/upload" + ) + + def test_3pid_endpoints(self) -> None: + """Test that 3pid account management endpoints that were removed in MSC2964 are no longer available.""" + + # Remains and requires auth: + self.expect_unauthorized("GET", "/_matrix/client/v3/account/3pid") + self.expect_unauthorized( + "POST", + "/_matrix/client/v3/account/3pid/bind", + { + "client_secret": "foo", + "id_access_token": "bar", + "id_server": "foo", + "sid": "bar", + }, + ) + self.expect_unauthorized("POST", "/_matrix/client/v3/account/3pid/unbind", {}) + + # These are gone: + self.expect_unrecognized( + "POST", "/_matrix/client/v3/account/3pid" + ) # deprecated + self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/add") + self.expect_unrecognized("POST", "/_matrix/client/v3/account/3pid/delete") + self.expect_unrecognized( + "POST", "/_matrix/client/v3/account/3pid/email/requestToken" + ) + self.expect_unrecognized( + "POST", "/_matrix/client/v3/account/3pid/msisdn/requestToken" + ) + + def test_account_management_endpoints_removed(self) -> None: + """Test that account management endpoints that were removed in MSC2964 are no longer available.""" + self.expect_unrecognized("POST", "/_matrix/client/v3/account/deactivate") + self.expect_unrecognized("POST", "/_matrix/client/v3/account/password") + self.expect_unrecognized( + "POST", "/_matrix/client/v3/account/password/email/requestToken" + ) + self.expect_unrecognized( + "POST", "/_matrix/client/v3/account/password/msisdn/requestToken" + ) + + def test_registration_endpoints_removed(self) -> None: + """Test that registration endpoints that were removed in MSC2964 are no longer available.""" + self.expect_unrecognized( + "GET", "/_matrix/client/v1/register/m.login.registration_token/validity" + ) + # This is still available for AS registrations + # self.expect_unrecognized("POST", "/_matrix/client/v3/register") + self.expect_unrecognized("GET", "/_matrix/client/v3/register/available") + self.expect_unrecognized( + "POST", "/_matrix/client/v3/register/email/requestToken" + ) + self.expect_unrecognized( + "POST", "/_matrix/client/v3/register/msisdn/requestToken" + ) + + def test_session_management_endpoints_removed(self) -> None: + """Test that session management endpoints that were removed in MSC2964 are no longer available.""" + self.expect_unrecognized("GET", "/_matrix/client/v3/login") + self.expect_unrecognized("POST", "/_matrix/client/v3/login") + self.expect_unrecognized("GET", "/_matrix/client/v3/login/sso/redirect") + self.expect_unrecognized("POST", "/_matrix/client/v3/logout") + self.expect_unrecognized("POST", "/_matrix/client/v3/logout/all") + self.expect_unrecognized("POST", "/_matrix/client/v3/refresh") + self.expect_unrecognized("GET", "/_matrix/static/client/login") + + def test_device_management_endpoints_removed(self) -> None: + """Test that device management endpoints that were removed in MSC2964 are no longer available.""" + self.expect_unrecognized("POST", "/_matrix/client/v3/delete_devices") + self.expect_unrecognized("DELETE", "/_matrix/client/v3/devices/{DEVICE}") + + def test_openid_endpoints_removed(self) -> None: + """Test that OpenID id_token endpoints that were removed in MSC2964 are no longer available.""" + self.expect_unrecognized( + "POST", "/_matrix/client/v3/user/{USERNAME}/openid/request_token" + ) + + def test_admin_api_endpoints_removed(self) -> None: + """Test that admin API endpoints that were removed in MSC2964 are no longer available.""" + self.expect_unrecognized("GET", "/_synapse/admin/v1/registration_tokens") + self.expect_unrecognized("POST", "/_synapse/admin/v1/registration_tokens/new") + self.expect_unrecognized("GET", "/_synapse/admin/v1/registration_tokens/abcd") + self.expect_unrecognized("PUT", "/_synapse/admin/v1/registration_tokens/abcd") + self.expect_unrecognized( + "DELETE", "/_synapse/admin/v1/registration_tokens/abcd" + ) + self.expect_unrecognized("POST", "/_synapse/admin/v1/reset_password/foo") + self.expect_unrecognized("POST", "/_synapse/admin/v1/users/foo/login") + self.expect_unrecognized("GET", "/_synapse/admin/v1/register") + self.expect_unrecognized("POST", "/_synapse/admin/v1/register") + self.expect_unrecognized("GET", "/_synapse/admin/v1/users/foo/admin") + self.expect_unrecognized("PUT", "/_synapse/admin/v1/users/foo/admin") + self.expect_unrecognized("POST", "/_synapse/admin/v1/account_validity/validity") + + def test_admin_token(self) -> None: + """The handler should return a requester with admin rights when admin_token is used.""" + self.http_client.request = AsyncMock( + return_value=FakeResponse.json(code=200, payload={"active": False}), + ) + + request = Mock(args={}) + request.args[b"access_token"] = [b"admin_token_value"] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + requester = self.get_success(self.auth.get_user_by_req(request)) + self.assertEqual( + requester.user.to_string(), "@%s:%s" % ("__oidc_admin", SERVER_NAME) + ) + self.assertEqual(requester.is_guest, False) + self.assertEqual(requester.device_id, None) + self.assertEqual( + get_awaitable_result(self.auth.is_server_admin(requester)), True + ) + + # There should be no call to the introspection endpoint + self.http_client.request.assert_not_called() diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 0a8bae54fb..e797aaae00 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,7 +13,7 @@ # limitations under the License. import os from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple -from unittest.mock import ANY, Mock, patch +from unittest.mock import ANY, AsyncMock, Mock, patch from urllib.parse import parse_qs, urlparse import pymacaroons @@ -28,7 +28,7 @@ from synapse.util import Clock from synapse.util.macaroons import get_value_from_macaroon from synapse.util.stringutils import random_string -from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock +from tests.test_utils import FakeResponse, get_awaitable_result from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer from tests.unittest import HomeserverTestCase, override_config @@ -157,15 +157,15 @@ class OidcHandlerTestCase(HomeserverTestCase): sso_handler = hs.get_sso_handler() # Mock the render error method. self.render_error = Mock(return_value=None) - sso_handler.render_error = self.render_error # type: ignore[assignment] + sso_handler.render_error = self.render_error # type: ignore[method-assign] # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 auth_handler = hs.get_auth_handler() # Mock the complete SSO login method. - self.complete_sso_login = simple_async_mock() - auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[assignment] + self.complete_sso_login = AsyncMock() + auth_handler.complete_sso_login = self.complete_sso_login # type: ignore[method-assign] return hs diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index aa91bc0a3d..11ec8c7f11 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -16,7 +16,9 @@ from http import HTTPStatus from typing import Any, Dict, List, Optional, Type, Union -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock + +from twisted.test.proto_helpers import MemoryReactor import synapse from synapse.api.constants import LoginType @@ -24,11 +26,12 @@ from synapse.api.errors import Codes from synapse.handlers.account import AccountHandler from synapse.module_api import ModuleApi from synapse.rest.client import account, devices, login, logout, register +from synapse.server import HomeServer from synapse.types import JsonDict, UserID +from synapse.util import Clock from tests import unittest from tests.server import FakeChannel -from tests.test_utils import make_awaitable from tests.unittest import override_config # Login flows we expect to appear in the list after the normal ones. @@ -162,10 +165,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): CALLBACK_USERNAME = "get_username_for_registration" CALLBACK_DISPLAYNAME = "get_displayname_for_registration" - def setUp(self) -> None: + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: # we use a global mock device, so make sure we are starting with a clean slate mock_password_provider.reset_mock() - super().setUp() + + # The mock password provider doesn't register the users, so ensure they + # are registered first. + self.register_user("u", "not-the-tested-password") + self.register_user("user", "not-the-tested-password") @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) def test_password_only_auth_progiver_login_legacy(self) -> None: @@ -177,7 +186,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) # check_password must return an awaitable - mock_password_provider.check_password.return_value = make_awaitable(True) + mock_password_provider.check_password = AsyncMock(return_value=True) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) self.assertEqual("@u:test", channel.json_body["user_id"]) @@ -185,22 +194,12 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # login with mxid should work too - channel = self._send_password_login("@u:bz", "p") + channel = self._send_password_login("@u:test", "p") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) - self.assertEqual("@u:bz", channel.json_body["user_id"]) - mock_password_provider.check_password.assert_called_once_with("@u:bz", "p") + self.assertEqual("@u:test", channel.json_body["user_id"]) + mock_password_provider.check_password.assert_called_once_with("@u:test", "p") mock_password_provider.reset_mock() - # try a weird username / pass. Honestly it's unclear what we *expect* to happen - # in these cases, but at least we can guard against the API changing - # unexpectedly - channel = self._send_password_login(" USER🙂NAME ", " pASS\U0001F622word ") - self.assertEqual(channel.code, HTTPStatus.OK, channel.result) - self.assertEqual("@ USER🙂NAME :test", channel.json_body["user_id"]) - mock_password_provider.check_password.assert_called_once_with( - "@ USER🙂NAME :test", " pASS😢word " - ) - @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) def test_password_only_auth_provider_ui_auth_legacy(self) -> None: self.password_only_auth_provider_ui_auth_test_body() @@ -208,18 +207,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): def password_only_auth_provider_ui_auth_test_body(self) -> None: """UI Auth should delegate correctly to the password provider""" - # create the user, otherwise access doesn't work - module_api = self.hs.get_module_api() - self.get_success(module_api.register_user("u")) - # log in twice, to get two devices - mock_password_provider.check_password.return_value = make_awaitable(True) + mock_password_provider.check_password = AsyncMock(return_value=True) tok1 = self.login("u", "p") self.login("u", "p", device_id="dev2") mock_password_provider.reset_mock() # have the auth provider deny the request to start with - mock_password_provider.check_password.return_value = make_awaitable(False) + mock_password_provider.check_password = AsyncMock(return_value=False) # make the initial request which returns a 401 session = self._start_delete_device_session(tok1, "dev2") @@ -233,7 +228,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # Finally, check the request goes through when we allow it - mock_password_provider.check_password.return_value = make_awaitable(True) + mock_password_provider.check_password = AsyncMock(return_value=True) channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") self.assertEqual(channel.code, 200) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") @@ -247,7 +242,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password.return_value = make_awaitable(False) + mock_password_provider.check_password = AsyncMock(return_value=False) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result) @@ -264,7 +259,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # have the auth provider deny the request - mock_password_provider.check_password.return_value = make_awaitable(False) + mock_password_provider.check_password = AsyncMock(return_value=False) # log in twice, to get two devices tok1 = self.login("localuser", "localpass") @@ -307,7 +302,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password.return_value = make_awaitable(False) + mock_password_provider.check_password = AsyncMock(return_value=False) channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -329,7 +324,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # allow login via the auth provider - mock_password_provider.check_password.return_value = make_awaitable(True) + mock_password_provider.check_password = AsyncMock(return_value=True) # log in twice, to get two devices tok1 = self.login("localuser", "p") @@ -346,7 +341,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_password.assert_not_called() # now try deleting with the local password - mock_password_provider.check_password.return_value = make_awaitable(False) + mock_password_provider.check_password = AsyncMock(return_value=False) channel = self._authed_delete_device( tok1, "dev2", session, "localuser", "localpass" ) @@ -400,30 +395,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result) mock_password_provider.check_auth.assert_not_called() - mock_password_provider.check_auth.return_value = make_awaitable( - ("@user:bz", None) - ) + mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None)) channel = self._send_login("test.login_type", "u", test_field="y") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) - self.assertEqual("@user:bz", channel.json_body["user_id"]) + self.assertEqual("@user:test", channel.json_body["user_id"]) mock_password_provider.check_auth.assert_called_once_with( "u", "test.login_type", {"test_field": "y"} ) mock_password_provider.reset_mock() - # try a weird username. Again, it's unclear what we *expect* to happen - # in these cases, but at least we can guard against the API changing - # unexpectedly - mock_password_provider.check_auth.return_value = make_awaitable( - ("@ MALFORMED! :bz", None) - ) - channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") - self.assertEqual(channel.code, HTTPStatus.OK, channel.result) - self.assertEqual("@ MALFORMED! :bz", channel.json_body["user_id"]) - mock_password_provider.check_auth.assert_called_once_with( - " USER🙂NAME ", "test.login_type", {"test_field": " abc "} - ) - @override_config(legacy_providers_config(LegacyCustomAuthProvider)) def test_custom_auth_provider_ui_auth_legacy(self) -> None: self.custom_auth_provider_ui_auth_test_body() @@ -464,9 +444,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # right params, but authing as the wrong user - mock_password_provider.check_auth.return_value = make_awaitable( - ("@user:bz", None) - ) + mock_password_provider.check_auth = AsyncMock(return_value=("@user:test", None)) body["auth"]["test_field"] = "foo" channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 403) @@ -477,8 +455,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # and finally, succeed - mock_password_provider.check_auth.return_value = make_awaitable( - ("@localuser:test", None) + mock_password_provider.check_auth = AsyncMock( + return_value=("@localuser:test", None) ) channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 200) @@ -495,14 +473,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.custom_auth_provider_callback_test_body() def custom_auth_provider_callback_test_body(self) -> None: - callback = Mock(return_value=make_awaitable(None)) + callback = AsyncMock(return_value=None) - mock_password_provider.check_auth.return_value = make_awaitable( - ("@user:bz", callback) + mock_password_provider.check_auth = AsyncMock( + return_value=("@user:test", callback) ) channel = self._send_login("test.login_type", "u", test_field="y") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) - self.assertEqual("@user:bz", channel.json_body["user_id"]) + self.assertEqual("@user:test", channel.json_body["user_id"]) mock_password_provider.check_auth.assert_called_once_with( "u", "test.login_type", {"test_field": "y"} ) @@ -512,7 +490,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): call_args, call_kwargs = callback.call_args # should be one positional arg self.assertEqual(len(call_args), 1) - self.assertEqual(call_args[0]["user_id"], "@user:bz") + self.assertEqual(call_args[0]["user_id"], "@user:test") for p in ["user_id", "access_token", "device_id", "home_server"]: self.assertIn(p, call_args[0]) @@ -633,8 +611,8 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): login is disabled""" # register the user and log in twice via the test login type to get two devices, self.register_user("localuser", "localpass") - mock_password_provider.check_auth.return_value = make_awaitable( - ("@localuser:test", None) + mock_password_provider.check_auth = AsyncMock( + return_value=("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") self.assertEqual(channel.code, HTTPStatus.OK, channel.result) @@ -852,11 +830,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): username: The username to use for the test. registration: Whether to test with registration URLs. """ - self.hs.get_identity_handler().send_threepid_validation = Mock( # type: ignore[assignment] - return_value=make_awaitable(0), + self.hs.get_identity_handler().send_threepid_validation = AsyncMock( # type: ignore[method-assign] + return_value=0 ) - m = Mock(return_value=make_awaitable(False)) + m = AsyncMock(return_value=False) self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] self.register_user(username, "password") @@ -886,7 +864,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): m.assert_called_once_with("email", "foo@test.com", registration) - m = Mock(return_value=make_awaitable(True)) + m = AsyncMock(return_value=True) self.hs.get_password_auth_provider().is_3pid_allowed_callbacks = [m] channel = self.make_request( diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py index 19f5322317..41c8c44e02 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py @@ -21,11 +21,12 @@ from signedjson.key import generate_signing_key from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, Membership, PresenceState -from synapse.api.presence import UserPresenceState +from synapse.api.presence import UserDevicePresenceState, UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events.builder import EventBuilder from synapse.federation.sender import FederationSender from synapse.handlers.presence import ( + BUSY_ONLINE_TIMEOUT, EXTERNAL_PROCESS_EXPIRY, FEDERATION_PING_INTERVAL, FEDERATION_TIMEOUT, @@ -38,6 +39,7 @@ from synapse.handlers.presence import ( from synapse.rest import admin from synapse.rest.client import room from synapse.server import HomeServer +from synapse.storage.database import LoggingDatabaseConnection from synapse.types import JsonDict, UserID, get_domain_from_id from synapse.util import Clock @@ -351,6 +353,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): def test_idle_timer(self) -> None: user_id = "@foo:bar" + device_id = "dev-1" status_msg = "I'm here!" now = 5000000 @@ -361,8 +364,21 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_user_sync_ts=now, status_msg=status_msg, ) + device_state = UserDevicePresenceState( + user_id=user_id, + device_id=device_id, + state=state.state, + last_active_ts=state.last_active_ts, + last_sync_ts=state.last_user_sync_ts, + ) - new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) + new_state = handle_timeout( + state, + is_mine=True, + syncing_device_ids=set(), + user_devices={device_id: device_state}, + now=now, + ) self.assertIsNotNone(new_state) assert new_state is not None @@ -375,6 +391,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): presence state into unavailable. """ user_id = "@foo:bar" + device_id = "dev-1" status_msg = "I'm here!" now = 5000000 @@ -385,8 +402,21 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_user_sync_ts=now, status_msg=status_msg, ) + device_state = UserDevicePresenceState( + user_id=user_id, + device_id=device_id, + state=state.state, + last_active_ts=state.last_active_ts, + last_sync_ts=state.last_user_sync_ts, + ) - new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) + new_state = handle_timeout( + state, + is_mine=True, + syncing_device_ids=set(), + user_devices={device_id: device_state}, + now=now, + ) self.assertIsNotNone(new_state) assert new_state is not None @@ -395,6 +425,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): def test_sync_timeout(self) -> None: user_id = "@foo:bar" + device_id = "dev-1" status_msg = "I'm here!" now = 5000000 @@ -405,8 +436,21 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, status_msg=status_msg, ) + device_state = UserDevicePresenceState( + user_id=user_id, + device_id=device_id, + state=state.state, + last_active_ts=state.last_active_ts, + last_sync_ts=state.last_user_sync_ts, + ) - new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) + new_state = handle_timeout( + state, + is_mine=True, + syncing_device_ids=set(), + user_devices={device_id: device_state}, + now=now, + ) self.assertIsNotNone(new_state) assert new_state is not None @@ -415,6 +459,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): def test_sync_online(self) -> None: user_id = "@foo:bar" + device_id = "dev-1" status_msg = "I'm here!" now = 5000000 @@ -425,9 +470,20 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, status_msg=status_msg, ) + device_state = UserDevicePresenceState( + user_id=user_id, + device_id=device_id, + state=state.state, + last_active_ts=state.last_active_ts, + last_sync_ts=state.last_user_sync_ts, + ) new_state = handle_timeout( - state, is_mine=True, syncing_user_ids={user_id}, now=now + state, + is_mine=True, + syncing_device_ids={(user_id, device_id)}, + user_devices={device_id: device_state}, + now=now, ) self.assertIsNotNone(new_state) @@ -437,6 +493,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): def test_federation_ping(self) -> None: user_id = "@foo:bar" + device_id = "dev-1" status_msg = "I'm here!" now = 5000000 @@ -448,14 +505,28 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1, status_msg=status_msg, ) + device_state = UserDevicePresenceState( + user_id=user_id, + device_id=device_id, + state=state.state, + last_active_ts=state.last_active_ts, + last_sync_ts=state.last_user_sync_ts, + ) - new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) + new_state = handle_timeout( + state, + is_mine=True, + syncing_device_ids=set(), + user_devices={device_id: device_state}, + now=now, + ) self.assertIsNotNone(new_state) self.assertEqual(state, new_state) def test_no_timeout(self) -> None: user_id = "@foo:bar" + device_id = "dev-1" now = 5000000 state = UserPresenceState.default(user_id) @@ -465,8 +536,21 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_user_sync_ts=now, last_federation_update_ts=now, ) + device_state = UserDevicePresenceState( + user_id=user_id, + device_id=device_id, + state=state.state, + last_active_ts=state.last_active_ts, + last_sync_ts=state.last_user_sync_ts, + ) - new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) + new_state = handle_timeout( + state, + is_mine=True, + syncing_device_ids=set(), + user_devices={device_id: device_state}, + now=now, + ) self.assertIsNone(new_state) @@ -484,8 +568,9 @@ class PresenceTimeoutTestCase(unittest.TestCase): status_msg=status_msg, ) + # Note that this is a remote user so we do not have their device information. new_state = handle_timeout( - state, is_mine=False, syncing_user_ids=set(), now=now + state, is_mine=False, syncing_device_ids=set(), user_devices={}, now=now ) self.assertIsNotNone(new_state) @@ -495,6 +580,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): def test_last_active(self) -> None: user_id = "@foo:bar" + device_id = "dev-1" status_msg = "I'm here!" now = 5000000 @@ -506,14 +592,150 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_federation_update_ts=now, status_msg=status_msg, ) + device_state = UserDevicePresenceState( + user_id=user_id, + device_id=device_id, + state=state.state, + last_active_ts=state.last_active_ts, + last_sync_ts=state.last_user_sync_ts, + ) - new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) + new_state = handle_timeout( + state, + is_mine=True, + syncing_device_ids=set(), + user_devices={device_id: device_state}, + now=now, + ) self.assertIsNotNone(new_state) self.assertEqual(state, new_state) +class PresenceHandlerInitTestCase(unittest.HomeserverTestCase): + def default_config(self) -> JsonDict: + config = super().default_config() + # Disable background tasks on this worker so that the PresenceHandler isn't + # loaded until we request it. + config["run_background_tasks_on"] = "other" + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user_id = f"@test:{self.hs.config.server.server_name}" + self.device_id = "dev-1" + + # Move the reactor to the initial time. + self.reactor.advance(1000) + now = self.clock.time_msec() + + main_store = hs.get_datastores().main + self.get_success( + main_store.update_presence( + [ + UserPresenceState( + user_id=self.user_id, + state=PresenceState.ONLINE, + last_active_ts=now, + last_federation_update_ts=now, + last_user_sync_ts=now, + status_msg=None, + currently_active=True, + ) + ] + ) + ) + + # Regenerate the preloaded presence information on PresenceStore. + def refill_presence(db_conn: LoggingDatabaseConnection) -> None: + main_store._presence_on_startup = main_store._get_active_presence(db_conn) + + self.get_success(main_store.db_pool.runWithConnection(refill_presence)) + + def test_restored_presence_idles(self) -> None: + """The presence state restored from the database should not persist forever.""" + + # Get the handler (which kicks off a bunch of timers). + presence_handler = self.hs.get_presence_handler() + + # Assert the user is online. + state = self.get_success( + presence_handler.get_state(UserID.from_string(self.user_id)) + ) + self.assertEqual(state.state, PresenceState.ONLINE) + + # Advance such that the user should timeout. + self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000) + self.reactor.pump([5]) + + # Check that the user is now offline. + state = self.get_success( + presence_handler.get_state(UserID.from_string(self.user_id)) + ) + self.assertEqual(state.state, PresenceState.OFFLINE) + + @parameterized.expand( + [ + (PresenceState.BUSY, PresenceState.BUSY), + (PresenceState.ONLINE, PresenceState.ONLINE), + (PresenceState.UNAVAILABLE, PresenceState.ONLINE), + # Offline syncs don't update the state. + (PresenceState.OFFLINE, PresenceState.ONLINE), + ] + ) + @unittest.override_config({"experimental_features": {"msc3026_enabled": True}}) + def test_restored_presence_online_after_sync( + self, sync_state: str, expected_state: str + ) -> None: + """ + The presence state restored from the database should be overridden with sync after a timeout. + + Args: + sync_state: The presence state of the new sync. + expected_state: The expected presence right after the sync. + """ + + # Get the handler (which kicks off a bunch of timers). + presence_handler = self.hs.get_presence_handler() + + # Assert the user is online, as restored. + state = self.get_success( + presence_handler.get_state(UserID.from_string(self.user_id)) + ) + self.assertEqual(state.state, PresenceState.ONLINE) + + # Advance slightly and sync. + self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000 / 2) + self.get_success( + presence_handler.user_syncing( + self.user_id, + self.device_id, + sync_state != PresenceState.OFFLINE, + sync_state, + ) + ) + + # Assert the user is in the expected state. + state = self.get_success( + presence_handler.get_state(UserID.from_string(self.user_id)) + ) + self.assertEqual(state.state, expected_state) + + # Advance such that the user's preloaded data times out, but not the new sync. + self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000 / 2) + self.reactor.pump([5]) + + # Check that the user is in the sync state (as the client is currently syncing still). + state = self.get_success( + presence_handler.get_state(UserID.from_string(self.user_id)) + ) + self.assertEqual(state.state, sync_state) + + class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): + user_id = "@test:server" + user_id_obj = UserID.from_string(user_id) + device_id = "dev-1" + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.presence_handler = hs.get_presence_handler() self.clock = hs.get_clock() @@ -522,62 +744,57 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): """Test that if an external process doesn't update the records for a while we time out their syncing users presence. """ - process_id = "1" - user_id = "@test:server" - # Notify handler that a user is now syncing. + # Create a worker and use it to handle /sync traffic instead. + # This is used to test that presence changes get replicated from workers + # to the main process correctly. + worker_to_sync_against = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "synchrotron"} + ) + worker_presence_handler = worker_to_sync_against.get_presence_handler() + self.get_success( - self.presence_handler.update_external_syncs_row( - process_id, user_id, True, self.clock.time_msec() - ) + worker_presence_handler.user_syncing( + self.user_id, self.device_id, True, PresenceState.ONLINE + ), + by=0.1, ) # Check that if we wait a while without telling the handler the user has # stopped syncing that their presence state doesn't get timed out. self.reactor.advance(EXTERNAL_PROCESS_EXPIRY / 2) - state = self.get_success( - self.presence_handler.get_state(UserID.from_string(user_id)) - ) + state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) self.assertEqual(state.state, PresenceState.ONLINE) # Check that if the external process timeout fires, then the syncing # user gets timed out self.reactor.advance(EXTERNAL_PROCESS_EXPIRY) - state = self.get_success( - self.presence_handler.get_state(UserID.from_string(user_id)) - ) + state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) self.assertEqual(state.state, PresenceState.OFFLINE) def test_user_goes_offline_by_timeout_status_msg_remain(self) -> None: """Test that if a user doesn't update the records for a while users presence goes `OFFLINE` because of timeout and `status_msg` remains. """ - user_id = "@test:server" status_msg = "I'm here!" # Mark user as online - self._set_presencestate_with_status_msg( - user_id, PresenceState.ONLINE, status_msg - ) + self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg) # Check that if we wait a while without telling the handler the user has # stopped syncing that their presence state doesn't get timed out. self.reactor.advance(SYNC_ONLINE_TIMEOUT / 2) - state = self.get_success( - self.presence_handler.get_state(UserID.from_string(user_id)) - ) + state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) self.assertEqual(state.state, PresenceState.ONLINE) self.assertEqual(state.status_msg, status_msg) # Check that if the timeout fires, then the syncing user gets timed out self.reactor.advance(SYNC_ONLINE_TIMEOUT) - state = self.get_success( - self.presence_handler.get_state(UserID.from_string(user_id)) - ) + state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) # status_msg should remain even after going offline self.assertEqual(state.state, PresenceState.OFFLINE) self.assertEqual(state.status_msg, status_msg) @@ -586,24 +803,19 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): """Test that if a user change presence manually to `OFFLINE` and no status is set, that `status_msg` is `None`. """ - user_id = "@test:server" status_msg = "I'm here!" # Mark user as online - self._set_presencestate_with_status_msg( - user_id, PresenceState.ONLINE, status_msg - ) + self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg) # Mark user as offline self.get_success( self.presence_handler.set_state( - UserID.from_string(user_id), {"presence": PresenceState.OFFLINE} + self.user_id_obj, self.device_id, {"presence": PresenceState.OFFLINE} ) ) - state = self.get_success( - self.presence_handler.get_state(UserID.from_string(user_id)) - ) + state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) self.assertEqual(state.state, PresenceState.OFFLINE) self.assertEqual(state.status_msg, None) @@ -611,41 +823,31 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): """Test that if a user change presence manually to `OFFLINE` and a status is set, that `status_msg` appears. """ - user_id = "@test:server" status_msg = "I'm here!" # Mark user as online - self._set_presencestate_with_status_msg( - user_id, PresenceState.ONLINE, status_msg - ) + self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg) # Mark user as offline - self._set_presencestate_with_status_msg( - user_id, PresenceState.OFFLINE, "And now here." - ) + self._set_presencestate_with_status_msg(PresenceState.OFFLINE, "And now here.") def test_user_reset_online_with_no_status(self) -> None: """Test that if a user set again the presence manually and no status is set, that `status_msg` is `None`. """ - user_id = "@test:server" status_msg = "I'm here!" # Mark user as online - self._set_presencestate_with_status_msg( - user_id, PresenceState.ONLINE, status_msg - ) + self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg) # Mark user as online again self.get_success( self.presence_handler.set_state( - UserID.from_string(user_id), {"presence": PresenceState.ONLINE} + self.user_id_obj, self.device_id, {"presence": PresenceState.ONLINE} ) ) - state = self.get_success( - self.presence_handler.get_state(UserID.from_string(user_id)) - ) + state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) # status_msg should remain even after going offline self.assertEqual(state.state, PresenceState.ONLINE) self.assertEqual(state.status_msg, None) @@ -654,33 +856,27 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): """Test that if a user set again the presence manually and status is `None`, that `status_msg` is `None`. """ - user_id = "@test:server" status_msg = "I'm here!" # Mark user as online - self._set_presencestate_with_status_msg( - user_id, PresenceState.ONLINE, status_msg - ) + self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg) # Mark user as online and `status_msg = None` - self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) + self._set_presencestate_with_status_msg(PresenceState.ONLINE, None) def test_set_presence_from_syncing_not_set(self) -> None: """Test that presence is not set by syncing if affect_presence is false""" - user_id = "@test:server" status_msg = "I'm here!" - self._set_presencestate_with_status_msg( - user_id, PresenceState.UNAVAILABLE, status_msg - ) + self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg) self.get_success( - self.presence_handler.user_syncing(user_id, False, PresenceState.ONLINE) + self.presence_handler.user_syncing( + self.user_id, self.device_id, False, PresenceState.ONLINE + ) ) - state = self.get_success( - self.presence_handler.get_state(UserID.from_string(user_id)) - ) + state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) # we should still be unavailable self.assertEqual(state.state, PresenceState.UNAVAILABLE) # and status message should still be the same @@ -688,50 +884,518 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): def test_set_presence_from_syncing_is_set(self) -> None: """Test that presence is set by syncing if affect_presence is true""" - user_id = "@test:server" status_msg = "I'm here!" - self._set_presencestate_with_status_msg( - user_id, PresenceState.UNAVAILABLE, status_msg + self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg) + + self.get_success( + self.presence_handler.user_syncing( + self.user_id, self.device_id, True, PresenceState.ONLINE + ) + ) + + state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) + # we should now be online + self.assertEqual(state.state, PresenceState.ONLINE) + + @parameterized.expand( + # A list of tuples of 4 strings: + # + # * The presence state of device 1. + # * The presence state of device 2. + # * The expected user presence state after both devices have synced. + # * The expected user presence state after device 1 has idled. + # * The expected user presence state after device 2 has idled. + # * True to use workers, False a monolith. + [ + (*cases, workers) + for workers in (False, True) + for cases in [ + # If both devices have the same state, online should eventually idle. + # Otherwise, the state doesn't change. + ( + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.ONLINE, + PresenceState.ONLINE, + PresenceState.ONLINE, + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + ), + ( + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + ), + ( + PresenceState.OFFLINE, + PresenceState.OFFLINE, + PresenceState.OFFLINE, + PresenceState.OFFLINE, + PresenceState.OFFLINE, + ), + # If the second device has a "lower" state it should fallback to it, + # except for "busy" which overrides. + ( + PresenceState.BUSY, + PresenceState.ONLINE, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.BUSY, + PresenceState.UNAVAILABLE, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.BUSY, + PresenceState.OFFLINE, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + ), + ( + PresenceState.ONLINE, + PresenceState.OFFLINE, + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + ), + ( + PresenceState.UNAVAILABLE, + PresenceState.OFFLINE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + ), + # If the second device has a "higher" state it should override. + ( + PresenceState.ONLINE, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.UNAVAILABLE, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.OFFLINE, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.UNAVAILABLE, + PresenceState.ONLINE, + PresenceState.ONLINE, + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + ), + ( + PresenceState.OFFLINE, + PresenceState.ONLINE, + PresenceState.ONLINE, + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + ), + ( + PresenceState.OFFLINE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + ), + ] + ], + name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[5] else 'monolith'}", + ) + @unittest.override_config({"experimental_features": {"msc3026_enabled": True}}) + def test_set_presence_from_syncing_multi_device( + self, + dev_1_state: str, + dev_2_state: str, + expected_state_1: str, + expected_state_2: str, + expected_state_3: str, + test_with_workers: bool, + ) -> None: + """ + Test the behaviour of multiple devices syncing at the same time. + + Roughly the user's presence state should be set to the "highest" priority + of all the devices. When a device then goes offline its state should be + discarded and the next highest should win. + + Note that these tests use the idle timer (and don't close the syncs), it + is unlikely that a *single* sync would last this long, but is close enough + to continually syncing with that current state. + """ + user_id = f"@test:{self.hs.config.server.server_name}" + + # By default, we call /sync against the main process. + worker_presence_handler = self.presence_handler + if test_with_workers: + # Create a worker and use it to handle /sync traffic instead. + # This is used to test that presence changes get replicated from workers + # to the main process correctly. + worker_to_sync_against = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "synchrotron"} + ) + worker_presence_handler = worker_to_sync_against.get_presence_handler() + + # 1. Sync with the first device. + self.get_success( + worker_presence_handler.user_syncing( + user_id, + "dev-1", + affect_presence=dev_1_state != PresenceState.OFFLINE, + presence_state=dev_1_state, + ), + by=0.01, ) + # 2. Wait half the idle timer. + self.reactor.advance(IDLE_TIMER / 1000 / 2) + self.reactor.pump([0.1]) + + # 3. Sync with the second device. self.get_success( - self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE) + worker_presence_handler.user_syncing( + user_id, + "dev-2", + affect_presence=dev_2_state != PresenceState.OFFLINE, + presence_state=dev_2_state, + ), + by=0.01, ) + # 4. Assert the expected presence state. state = self.get_success( self.presence_handler.get_state(UserID.from_string(user_id)) ) - # we should now be online - self.assertEqual(state.state, PresenceState.ONLINE) + self.assertEqual(state.state, expected_state_1) + if test_with_workers: + state = self.get_success( + worker_presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, expected_state_1) - def test_set_presence_from_syncing_keeps_status(self) -> None: - """Test that presence set by syncing retains status message""" - user_id = "@test:server" - status_msg = "I'm here!" + # When testing with workers, make another random sync (with any *different* + # user) to keep the process information from expiring. + # + # This is due to EXTERNAL_PROCESS_EXPIRY being equivalent to IDLE_TIMER. + if test_with_workers: + with self.get_success( + worker_presence_handler.user_syncing( + f"@other-user:{self.hs.config.server.server_name}", + "dev-3", + affect_presence=True, + presence_state=PresenceState.ONLINE, + ), + by=0.01, + ): + pass - self._set_presencestate_with_status_msg( - user_id, PresenceState.UNAVAILABLE, status_msg + # 5. Advance such that the first device should be discarded (the idle timer), + # then pump so _handle_timeouts function to called. + self.reactor.advance(IDLE_TIMER / 1000 / 2) + self.reactor.pump([0.01]) + + # 6. Assert the expected presence state. + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) ) + self.assertEqual(state.state, expected_state_2) + if test_with_workers: + state = self.get_success( + worker_presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, expected_state_2) - self.get_success( - self.presence_handler.user_syncing(user_id, True, PresenceState.ONLINE) + # 7. Advance such that the second device should be discarded (half the idle timer), + # then pump so _handle_timeouts function to called. + self.reactor.advance(IDLE_TIMER / 1000 / 2) + self.reactor.pump([0.1]) + + # 8. The devices are still "syncing" (the sync context managers were never + # closed), so might idle. + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) ) + self.assertEqual(state.state, expected_state_3) + if test_with_workers: + state = self.get_success( + worker_presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, expected_state_3) + + @parameterized.expand( + # A list of tuples of 4 strings: + # + # * The presence state of device 1. + # * The presence state of device 2. + # * The expected user presence state after both devices have synced. + # * The expected user presence state after device 1 has stopped syncing. + # * True to use workers, False a monolith. + [ + (*cases, workers) + for workers in (False, True) + for cases in [ + # If both devices have the same state, nothing exciting should happen. + ( + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.ONLINE, + PresenceState.ONLINE, + PresenceState.ONLINE, + PresenceState.ONLINE, + ), + ( + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + ), + ( + PresenceState.OFFLINE, + PresenceState.OFFLINE, + PresenceState.OFFLINE, + PresenceState.OFFLINE, + ), + # If the second device has a "lower" state it should fallback to it, + # except for "busy" which overrides. + ( + PresenceState.BUSY, + PresenceState.ONLINE, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.BUSY, + PresenceState.UNAVAILABLE, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.BUSY, + PresenceState.OFFLINE, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + PresenceState.ONLINE, + PresenceState.UNAVAILABLE, + ), + ( + PresenceState.ONLINE, + PresenceState.OFFLINE, + PresenceState.ONLINE, + PresenceState.OFFLINE, + ), + ( + PresenceState.UNAVAILABLE, + PresenceState.OFFLINE, + PresenceState.UNAVAILABLE, + PresenceState.OFFLINE, + ), + # If the second device has a "higher" state it should override. + ( + PresenceState.ONLINE, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.UNAVAILABLE, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.OFFLINE, + PresenceState.BUSY, + PresenceState.BUSY, + PresenceState.BUSY, + ), + ( + PresenceState.UNAVAILABLE, + PresenceState.ONLINE, + PresenceState.ONLINE, + PresenceState.ONLINE, + ), + ( + PresenceState.OFFLINE, + PresenceState.ONLINE, + PresenceState.ONLINE, + PresenceState.ONLINE, + ), + ( + PresenceState.OFFLINE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + PresenceState.UNAVAILABLE, + ), + ] + ], + name_func=lambda testcase_func, param_num, params: f"{testcase_func.__name__}_{param_num}_{'workers' if params.args[4] else 'monolith'}", + ) + @unittest.override_config({"experimental_features": {"msc3026_enabled": True}}) + def test_set_presence_from_non_syncing_multi_device( + self, + dev_1_state: str, + dev_2_state: str, + expected_state_1: str, + expected_state_2: str, + test_with_workers: bool, + ) -> None: + """ + Test the behaviour of multiple devices syncing at the same time. + + Roughly the user's presence state should be set to the "highest" priority + of all the devices. When a device then goes offline its state should be + discarded and the next highest should win. + + Note that these tests use the idle timer (and don't close the syncs), it + is unlikely that a *single* sync would last this long, but is close enough + to continually syncing with that current state. + """ + user_id = f"@test:{self.hs.config.server.server_name}" + + # By default, we call /sync against the main process. + worker_presence_handler = self.presence_handler + if test_with_workers: + # Create a worker and use it to handle /sync traffic instead. + # This is used to test that presence changes get replicated from workers + # to the main process correctly. + worker_to_sync_against = self.make_worker_hs( + "synapse.app.generic_worker", {"worker_name": "synchrotron"} + ) + worker_presence_handler = worker_to_sync_against.get_presence_handler() + + # 1. Sync with the first device. + sync_1 = self.get_success( + worker_presence_handler.user_syncing( + user_id, + "dev-1", + affect_presence=dev_1_state != PresenceState.OFFLINE, + presence_state=dev_1_state, + ), + by=0.1, + ) + + # 2. Sync with the second device. + sync_2 = self.get_success( + worker_presence_handler.user_syncing( + user_id, + "dev-2", + affect_presence=dev_2_state != PresenceState.OFFLINE, + presence_state=dev_2_state, + ), + by=0.1, + ) + + # 3. Assert the expected presence state. + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, expected_state_1) + if test_with_workers: + state = self.get_success( + worker_presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, expected_state_1) + + # 4. Disconnect the first device. + with sync_1: + pass + + # 5. Advance such that the first device should be discarded (the sync timeout), + # then pump so _handle_timeouts function to called. + self.reactor.advance(SYNC_ONLINE_TIMEOUT / 1000) + self.reactor.pump([5]) + # 6. Assert the expected presence state. state = self.get_success( self.presence_handler.get_state(UserID.from_string(user_id)) ) + self.assertEqual(state.state, expected_state_2) + if test_with_workers: + state = self.get_success( + worker_presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, expected_state_2) + + # 7. Disconnect the second device. + with sync_2: + pass + + # 8. Advance such that the second device should be discarded (the sync timeout), + # then pump so _handle_timeouts function to called. + if dev_1_state == PresenceState.BUSY or dev_2_state == PresenceState.BUSY: + timeout = BUSY_ONLINE_TIMEOUT + else: + timeout = SYNC_ONLINE_TIMEOUT + self.reactor.advance(timeout / 1000) + self.reactor.pump([5]) + + # 9. There are no more devices, should be offline. + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.OFFLINE) + if test_with_workers: + state = self.get_success( + worker_presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.OFFLINE) + + def test_set_presence_from_syncing_keeps_status(self) -> None: + """Test that presence set by syncing retains status message""" + status_msg = "I'm here!" + + self._set_presencestate_with_status_msg(PresenceState.UNAVAILABLE, status_msg) + + self.get_success( + self.presence_handler.user_syncing( + self.user_id, self.device_id, True, PresenceState.ONLINE + ) + ) + + state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) # our status message should be the same as it was before self.assertEqual(state.status_msg, status_msg) @parameterized.expand([(False,), (True,)]) - @unittest.override_config( - { - "experimental_features": { - "msc3026_enabled": True, - }, - } - ) + @unittest.override_config({"experimental_features": {"msc3026_enabled": True}}) def test_set_presence_from_syncing_keeps_busy( self, test_with_workers: bool ) -> None: @@ -741,7 +1405,6 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): test_with_workers: If True, check the presence state of the user by calling /sync against a worker, rather than the main process. """ - user_id = "@test:server" status_msg = "I'm busy!" # By default, we call /sync against the main process. @@ -751,48 +1414,60 @@ class PresenceHandlerTestCase(BaseMultiWorkerStreamTestCase): # This is used to test that presence changes get replicated from workers # to the main process correctly. worker_to_sync_against = self.make_worker_hs( - "synapse.app.generic_worker", {"worker_name": "presence_writer"} + "synapse.app.generic_worker", {"worker_name": "synchrotron"} ) # Set presence to BUSY - self._set_presencestate_with_status_msg(user_id, PresenceState.BUSY, status_msg) + self._set_presencestate_with_status_msg(PresenceState.BUSY, status_msg) # Perform a sync with a presence state other than busy. This should NOT change # our presence status; we only change from busy if we explicitly set it via # /presence/*. self.get_success( worker_to_sync_against.get_presence_handler().user_syncing( - user_id, True, PresenceState.ONLINE - ) + self.user_id, self.device_id, True, PresenceState.ONLINE + ), + by=0.1, ) # Check against the main process that the user's presence did not change. - state = self.get_success( - self.presence_handler.get_state(UserID.from_string(user_id)) - ) + state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) # we should still be busy self.assertEqual(state.state, PresenceState.BUSY) + # Advance such that the device would be discarded if it was not busy, + # then pump so _handle_timeouts function to called. + self.reactor.advance(IDLE_TIMER / 1000) + self.reactor.pump([5]) + + # The account should still be busy. + state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) + self.assertEqual(state.state, PresenceState.BUSY) + + # Ensure that a /presence call can set the user *off* busy. + self._set_presencestate_with_status_msg(PresenceState.ONLINE, status_msg) + + state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) + self.assertEqual(state.state, PresenceState.ONLINE) + def _set_presencestate_with_status_msg( - self, user_id: str, state: str, status_msg: Optional[str] + self, state: str, status_msg: Optional[str] ) -> None: """Set a PresenceState and status_msg and check the result. Args: - user_id: User for that the status is to be set. state: The new PresenceState. status_msg: Status message that is to be set. """ self.get_success( self.presence_handler.set_state( - UserID.from_string(user_id), + self.user_id_obj, + self.device_id, {"presence": state, "status_msg": status_msg}, ) ) - new_state = self.get_success( - self.presence_handler.get_state(UserID.from_string(user_id)) - ) + new_state = self.get_success(self.presence_handler.get_state(self.user_id_obj)) self.assertEqual(new_state.state, state) self.assertEqual(new_state.status_msg, status_msg) @@ -812,8 +1487,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): prev_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) - self.queue.send_presence_to_destinations((state3,), ("dest3",)) + self.get_success( + self.queue.send_presence_to_destinations( + (state1, state2), ("dest1", "dest2") + ) + ) + self.get_success( + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + ) now_token = self.queue.get_current_token(self.instance_name) @@ -849,11 +1530,17 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): prev_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) + self.get_success( + self.queue.send_presence_to_destinations( + (state1, state2), ("dest1", "dest2") + ) + ) now_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state3,), ("dest3",)) + self.get_success( + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + ) rows, upto_token, limited = self.get_success( self.queue.get_replication_rows("master", prev_token, now_token, 10) @@ -892,8 +1579,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): prev_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) - self.queue.send_presence_to_destinations((state3,), ("dest3",)) + self.get_success( + self.queue.send_presence_to_destinations( + (state1, state2), ("dest1", "dest2") + ) + ) + self.get_success( + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + ) self.reactor.advance(10 * 60 * 1000) @@ -908,8 +1601,14 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): prev_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) - self.queue.send_presence_to_destinations((state3,), ("dest3",)) + self.get_success( + self.queue.send_presence_to_destinations( + (state1, state2), ("dest1", "dest2") + ) + ) + self.get_success( + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + ) now_token = self.queue.get_current_token(self.instance_name) @@ -936,11 +1635,17 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): prev_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) + self.get_success( + self.queue.send_presence_to_destinations( + (state1, state2), ("dest1", "dest2") + ) + ) self.reactor.advance(2 * 60 * 1000) - self.queue.send_presence_to_destinations((state3,), ("dest3",)) + self.get_success( + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + ) self.reactor.advance(4 * 60 * 1000) @@ -952,15 +1657,18 @@ class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): self.assertEqual(upto_token, now_token) self.assertFalse(limited) - expected_rows = [ - (2, ("dest3", "@user3:test")), - ] self.assertCountEqual(rows, []) prev_token = self.queue.get_current_token(self.instance_name) - self.queue.send_presence_to_destinations((state1, state2), ("dest1", "dest2")) - self.queue.send_presence_to_destinations((state3,), ("dest3",)) + self.get_success( + self.queue.send_presence_to_destinations( + (state1, state2), ("dest1", "dest2") + ) + ) + self.get_success( + self.queue.send_presence_to_destinations((state3,), ("dest3",)) + ) now_token = self.queue.get_current_token(self.instance_name) @@ -993,7 +1701,6 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver( "server", - federation_http_client=None, federation_sender=Mock(spec=FederationSender), ) return hs @@ -1033,7 +1740,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # Mark test2 as online, test will be offline with a last_active of 0 self.get_success( self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + UserID.from_string("@test2:server"), + "dev-1", + {"presence": PresenceState.ONLINE}, ) ) self.reactor.pump([0]) # Wait for presence updates to be handled @@ -1080,7 +1789,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # Mark test as online self.get_success( self.presence_handler.set_state( - UserID.from_string("@test:server"), {"presence": PresenceState.ONLINE} + UserID.from_string("@test:server"), + "dev-1", + {"presence": PresenceState.ONLINE}, ) ) @@ -1088,7 +1799,9 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): # Note we don't join them to the room yet self.get_success( self.presence_handler.set_state( - UserID.from_string("@test2:server"), {"presence": PresenceState.ONLINE} + UserID.from_string("@test2:server"), + "dev-1", + {"presence": PresenceState.ONLINE}, ) ) @@ -1145,7 +1858,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase): ) event = self.get_success( - builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None) + builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None) ) self.get_success(self.federation_event_handler.on_receive_pdu(hostname, event)) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 64a9a22afe..f9b292b9ec 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Awaitable, Callable, Dict -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from parameterized import parameterized @@ -26,7 +26,6 @@ from synapse.types import JsonDict, UserID from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable class ProfileTestCase(unittest.HomeserverTestCase): @@ -35,7 +34,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - self.mock_federation = Mock() + self.mock_federation = AsyncMock() self.mock_registry = Mock() self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} @@ -80,11 +79,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - ( - self.get_success( - self.store.get_profile_displayname(self.frank.localpart) - ) - ), + (self.get_success(self.store.get_profile_displayname(self.frank))), "Frank Jr.", ) @@ -96,11 +91,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - ( - self.get_success( - self.store.get_profile_displayname(self.frank.localpart) - ) - ), + (self.get_success(self.store.get_profile_displayname(self.frank))), "Frank", ) @@ -112,7 +103,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - self.get_success(self.store.get_profile_displayname(self.frank.localpart)) + self.get_success(self.store.get_profile_displayname(self.frank)) ) def test_set_my_name_if_disabled(self) -> None: @@ -122,11 +113,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.get_success(self.store.set_profile_displayname(self.frank, "Frank")) self.assertEqual( - ( - self.get_success( - self.store.get_profile_displayname(self.frank.localpart) - ) - ), + (self.get_success(self.store.get_profile_displayname(self.frank))), "Frank", ) @@ -147,9 +134,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_get_other_name(self) -> None: - self.mock_federation.make_query.return_value = make_awaitable( - {"displayname": "Alice"} - ) + self.mock_federation.make_query.return_value = {"displayname": "Alice"} displayname = self.get_success(self.handler.get_displayname(self.alice)) @@ -191,6 +176,16 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual("http://my.server/me.png", avatar_url) + def test_get_profile_empty_displayname(self) -> None: + self.get_success(self.store.set_profile_displayname(self.frank, None)) + self.get_success( + self.store.set_profile_avatar_url(self.frank, "http://my.server/me.png") + ) + + profile = self.get_success(self.handler.get_profile(self.frank.to_string())) + + self.assertEqual("http://my.server/me.png", profile["avatar_url"]) + def test_set_my_avatar(self) -> None: self.get_success( self.handler.set_avatar_url( @@ -201,7 +196,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), + (self.get_success(self.store.get_profile_avatar_url(self.frank))), "http://my.server/pic.gif", ) @@ -215,7 +210,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), + (self.get_success(self.store.get_profile_avatar_url(self.frank))), "http://my.server/me.png", ) @@ -229,7 +224,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertIsNone( - (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), + (self.get_success(self.store.get_profile_avatar_url(self.frank))), ) def test_set_my_avatar_if_disabled(self) -> None: @@ -241,7 +236,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) self.assertEqual( - (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), + (self.get_success(self.store.get_profile_avatar_url(self.frank))), "http://my.server/me.png", ) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 73822b07a5..e9fbf32c7c 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -13,11 +13,11 @@ # limitations under the License. from typing import Any, Collection, List, Optional, Tuple -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor -from synapse.api.auth import Auth +from synapse.api.auth.internal import InternalAuth from synapse.api.constants import UserTypes from synapse.api.errors import ( CodeMessageException, @@ -38,7 +38,6 @@ from synapse.types import ( ) from synapse.util import Clock -from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import mock_getRawHeaders @@ -203,24 +202,22 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self) -> None: - self.store.count_monthly_users = Mock( # type: ignore[assignment] - return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) + self.store.count_monthly_users = AsyncMock( # type: ignore[method-assign] + return_value=self.hs.config.server.max_mau_value - 1 ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_blocked(self) -> None: - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.lots_of_users) - ) + self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), ResourceLimitError, ) - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.server.max_mau_value) + self.store.get_monthly_active_count = AsyncMock( + return_value=self.hs.config.server.max_mau_value ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -229,15 +226,13 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_register_mau_blocked(self) -> None: - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.lots_of_users) - ) + self.store.get_monthly_active_count = AsyncMock(return_value=self.lots_of_users) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError ) - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.server.max_mau_value) + self.store.get_monthly_active_count = AsyncMock( + return_value=self.hs.config.server.max_mau_value ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError @@ -292,7 +287,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"auto_join_rooms": ["#room:test"]}) def test_auto_create_auto_join_rooms_when_user_is_not_a_real_user(self) -> None: room_alias_str = "#room:test" - self.store.is_real_user = Mock(return_value=make_awaitable(False)) + self.store.is_real_user = AsyncMock(return_value=False) user_id = self.get_success(self.handler.register_user(localpart="support")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -304,8 +299,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_auto_create_auto_join_rooms_when_user_is_the_first_real_user(self) -> None: room_alias_str = "#room:test" - self.store.count_real_users = Mock(return_value=make_awaitable(1)) # type: ignore[assignment] - self.store.is_real_user = Mock(return_value=make_awaitable(True)) + self.store.count_real_users = AsyncMock(return_value=1) # type: ignore[method-assign] + self.store.is_real_user = AsyncMock(return_value=True) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) directory_handler = self.hs.get_directory_handler() @@ -319,8 +314,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase): def test_auto_create_auto_join_rooms_when_user_is_not_the_first_real_user( self, ) -> None: - self.store.count_real_users = Mock(return_value=make_awaitable(2)) # type: ignore[assignment] - self.store.is_real_user = Mock(return_value=make_awaitable(True)) + self.store.count_real_users = AsyncMock(return_value=2) # type: ignore[method-assign] + self.store.is_real_user = AsyncMock(return_value=True) user_id = self.get_success(self.handler.register_user(localpart="real")) rooms = self.get_success(self.store.get_rooms_for_user(user_id)) self.assertEqual(len(rooms), 0) @@ -587,17 +582,16 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertFalse(self.get_success(d)) def test_invalid_user_id(self) -> None: - invalid_user_id = "+abcd" + invalid_user_id = "^abcd" self.get_failure( self.handler.register_user(localpart=invalid_user_id), SynapseError ) - @override_config({"experimental_features": {"msc4009_e164_mxids": True}}) - def text_extended_user_ids(self) -> None: - """+ should be allowed according to MSC4009.""" - valid_user_id = "+1234" + def test_special_chars(self) -> None: + """Ensure that characters which are allowed in Matrix IDs work.""" + valid_user_id = "a1234_-./=+" user_id = self.get_success(self.handler.register_user(localpart=valid_user_id)) - self.assertEqual(user_id, valid_user_id) + self.assertEqual(user_id, f"@{valid_user_id}:test") def test_invalid_user_id_length(self) -> None: invalid_user_id = "x" * 256 @@ -683,7 +677,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): request = Mock(args={}) request.args[b"access_token"] = [token.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - auth = Auth(self.hs) + auth = InternalAuth(self.hs) requester = self.get_success(auth.get_user_by_req(request)) self.assertTrue(requester.shadow_banned) diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py index a444d822cd..3e28117e2c 100644 --- a/tests/handlers/test_room_member.py +++ b/tests/handlers/test_room_member.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, patch from twisted.test.proto_helpers import MemoryReactor @@ -16,7 +16,6 @@ from synapse.util import Clock from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request -from tests.test_utils import make_awaitable from tests.unittest import ( FederatingHomeserverTestCase, HomeserverTestCase, @@ -154,25 +153,21 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase): None, ) - mock_make_membership_event = Mock( - return_value=make_awaitable( - ( - self.OTHER_SERVER_NAME, - join_event, - self.hs.config.server.default_room_version, - ) + mock_make_membership_event = AsyncMock( + return_value=( + self.OTHER_SERVER_NAME, + join_event, + self.hs.config.server.default_room_version, ) ) - mock_send_join = Mock( - return_value=make_awaitable( - SendJoinResult( - join_event, - self.OTHER_SERVER_NAME, - state=[create_event], - auth_chain=[create_event], - partial_state=False, - servers_in_room=frozenset(), - ) + mock_send_join = AsyncMock( + return_value=SendJoinResult( + join_event, + self.OTHER_SERVER_NAME, + state=[create_event], + auth_chain=[create_event], + partial_state=False, + servers_in_room=frozenset(), ) ) @@ -333,6 +328,27 @@ class RoomMemberMasterHandlerTestCase(HomeserverTestCase): self.get_success(self.store.is_locally_forgotten_room(self.room_id)) ) + def test_leave_and_unforget(self) -> None: + """Tests if rejoining a room unforgets the room, so that it shows up in sync again.""" + self.helper.join(self.room_id, user=self.bob, tok=self.bob_token) + + # alice is not the last room member that leaves and forgets the room + self.helper.leave(self.room_id, user=self.alice, tok=self.alice_token) + self.get_success(self.handler.forget(self.alice_ID, self.room_id)) + self.assertTrue( + self.get_success(self.store.did_forget(self.alice, self.room_id)) + ) + + self.helper.join(self.room_id, user=self.alice, tok=self.alice_token) + self.assertFalse( + self.get_success(self.store.did_forget(self.alice, self.room_id)) + ) + + # the server has not forgotten the room + self.assertFalse( + self.get_success(self.store.is_locally_forgotten_room(self.room_id)) + ) + @override_config({"forget_rooms_on_leave": True}) def test_leave_and_auto_forget(self) -> None: """Tests the `forget_rooms_on_leave` config option.""" diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index b5c772a7ae..00f4e181e8 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Dict, Optional, Set, Tuple -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import attr @@ -25,7 +25,6 @@ from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock -from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config # Check if we have the dependencies to run the tests. @@ -134,7 +133,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # send a mocked-up SAML response to the callback saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) @@ -164,7 +163,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # Map a user via SSO. saml_response = FakeAuthnResponse( @@ -206,11 +205,11 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # mock out the error renderer too sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] + sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign] saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"}) request = _mock_request() @@ -227,9 +226,9 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler and error renderer auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] sso_handler = self.hs.get_sso_handler() - sso_handler.render_error = Mock(return_value=None) # type: ignore[assignment] + sso_handler.render_error = Mock(return_value=None) # type: ignore[method-assign] # register a user to occupy the first-choice MXID store = self.hs.get_datastores().main @@ -312,7 +311,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # stub out the auth handler auth_handler = self.hs.get_auth_handler() - auth_handler.complete_sso_login = simple_async_mock() # type: ignore[assignment] + auth_handler.complete_sso_login = AsyncMock() # type: ignore[method-assign] # The response doesn't have the proper userGroup or department. saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"}) diff --git a/tests/handlers/test_send_email.py b/tests/handlers/test_send_email.py index 8b6e4a40b6..a066745d70 100644 --- a/tests/handlers/test_send_email.py +++ b/tests/handlers/test_send_email.py @@ -13,19 +13,40 @@ # limitations under the License. -from typing import Callable, List, Tuple +from typing import Callable, List, Tuple, Type, Union +from unittest.mock import patch from zope.interface import implementer from twisted.internet import defer -from twisted.internet.address import IPv4Address +from twisted.internet._sslverify import ClientTLSOptions +from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.defer import ensureDeferred +from twisted.internet.interfaces import IProtocolFactory +from twisted.internet.ssl import ContextFactory from twisted.mail import interfaces, smtp from tests.server import FakeTransport from tests.unittest import HomeserverTestCase, override_config +def TestingESMTPTLSClientFactory( + contextFactory: ContextFactory, + _connectWrapped: bool, + wrappedProtocol: IProtocolFactory, +) -> IProtocolFactory: + """We use this to pass through in testing without using TLS, but + saving the context information to check that it would have happened. + + Note that this is what the MemoryReactor does on connectSSL. + It only saves the contextFactory, but starts the connection with the + underlying Factory. + See: L{twisted.internet.testing.MemoryReactor.connectSSL}""" + + wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined] + return wrappedProtocol + + @implementer(interfaces.IMessageDelivery) class _DummyMessageDelivery: def __init__(self) -> None: @@ -75,7 +96,13 @@ class _DummyMessage: pass -class SendEmailHandlerTestCase(HomeserverTestCase): +class SendEmailHandlerTestCaseIPv4(HomeserverTestCase): + ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address + + def setUp(self) -> None: + super().setUp() + self.reactor.lookups["localhost"] = "127.0.0.1" + def test_send_email(self) -> None: """Happy-path test that we can send email to a non-TLS server.""" h = self.hs.get_send_email_handler() @@ -89,7 +116,7 @@ class SendEmailHandlerTestCase(HomeserverTestCase): (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[ 0 ] - self.assertEqual(host, "localhost") + self.assertEqual(host, self.reactor.lookups["localhost"]) self.assertEqual(port, 25) # wire it up to an SMTP server @@ -105,7 +132,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase): FakeTransport( client_protocol, self.reactor, - peer_address=IPv4Address("TCP", "127.0.0.1", 1234), + peer_address=self.ip_class( + "TCP", self.reactor.lookups["localhost"], 1234 + ), ) ) @@ -118,6 +147,10 @@ class SendEmailHandlerTestCase(HomeserverTestCase): self.assertEqual(str(user), "foo@bar.com") self.assertIn(b"Subject: test subject", msg) + @patch( + "synapse.handlers.send_email.TLSMemoryBIOFactory", + TestingESMTPTLSClientFactory, + ) @override_config( { "email": { @@ -135,17 +168,23 @@ class SendEmailHandlerTestCase(HomeserverTestCase): ) ) # there should be an attempt to connect to localhost:465 - self.assertEqual(len(self.reactor.sslClients), 1) + self.assertEqual(len(self.reactor.tcpClients), 1) ( host, port, client_factory, - contextFactory, _timeout, _bindAddress, - ) = self.reactor.sslClients[0] - self.assertEqual(host, "localhost") + ) = self.reactor.tcpClients[0] + self.assertEqual(host, self.reactor.lookups["localhost"]) self.assertEqual(port, 465) + # We need to make sure that TLS is happenning + self.assertIsInstance( + client_factory._wrappedFactory._testingContextFactory, + ClientTLSOptions, + ) + # And since we use endpoints, they go through reactor.connectTCP + # which works differently to connectSSL on the testing reactor # wire it up to an SMTP server message_delivery = _DummyMessageDelivery() @@ -160,7 +199,9 @@ class SendEmailHandlerTestCase(HomeserverTestCase): FakeTransport( client_protocol, self.reactor, - peer_address=IPv4Address("TCP", "127.0.0.1", 1234), + peer_address=self.ip_class( + "TCP", self.reactor.lookups["localhost"], 1234 + ), ) ) @@ -172,3 +213,11 @@ class SendEmailHandlerTestCase(HomeserverTestCase): user, msg = message_delivery.messages.pop() self.assertEqual(str(user), "foo@bar.com") self.assertIn(b"Subject: test subject", msg) + + +class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4): + ip_class = IPv6Address + + def setUp(self) -> None: + super().setUp() + self.reactor.lookups["localhost"] = "::1" diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index 0d9a3de92a..948d04fc32 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Optional -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import AsyncMock, Mock, patch from twisted.test.proto_helpers import MemoryReactor @@ -29,7 +29,6 @@ from synapse.util import Clock import tests.unittest import tests.utils -from tests.test_utils import make_awaitable class SyncTestCase(tests.unittest.HomeserverTestCase): @@ -163,7 +162,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Blow away caches (supported room versions can only change due to a restart). self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() self.store.get_rooms_for_user.invalidate_all() - self.get_success(self.store._get_event_cache.clear()) + self.store._get_event_cache.clear() self.store._event_ref.clear() # The rooms should be excluded from the sync response. @@ -253,8 +252,8 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): mocked_get_prev_events = patch.object( self.hs.get_datastores().main, "get_prev_events_for_room", - new_callable=MagicMock, - return_value=make_awaitable([last_room_creation_event_id]), + new_callable=AsyncMock, + return_value=[last_room_creation_event_id], ) with mocked_get_prev_events: self.helper.join(room_id, eve, tok=eve_token) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 94518a7196..3060bc9744 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -15,7 +15,9 @@ import json from typing import Dict, List, Set -from unittest.mock import ANY, Mock, call +from unittest.mock import ANY, AsyncMock, Mock, call + +from netaddr import IPSet from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource @@ -24,13 +26,13 @@ from synapse.api.constants import EduTypes from synapse.api.errors import AuthError from synapse.federation.transport.server import TransportLayerServer from synapse.handlers.typing import TypingWriterHandler +from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.server import HomeServer -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types import JsonDict, Requester, StreamKeyType, UserID, create_requester from synapse.util import Clock from tests import unittest from tests.server import ThreadedMemoryReactorClock -from tests.test_utils import make_awaitable from tests.unittest import override_config # Some local users to test with @@ -71,11 +73,18 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): # we mock out the keyring so as to skip the authentication check on the # federation API call. mock_keyring = Mock(spec=["verify_json_for_server"]) - mock_keyring.verify_json_for_server.return_value = make_awaitable(True) + mock_keyring.verify_json_for_server = AsyncMock(return_value=True) # we mock out the federation client too - self.mock_federation_client = Mock(spec=["put_json"]) - self.mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) + self.mock_federation_client = AsyncMock(spec=["put_json"]) + self.mock_federation_client.put_json.return_value = (200, "OK") + self.mock_federation_client.agent = MatrixFederationAgent( + reactor, + tls_client_options_factory=None, + user_agent=b"SynapseInTrialTest/0.0.0", + ip_allowlist=None, + ip_blocklist=IPSet(), + ) # the tests assume that we are starting at unix time 1000 reactor.pump((1000,)) @@ -111,20 +120,16 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore = hs.get_datastores().main - self.datastore.get_destination_retry_timings = Mock( - return_value=make_awaitable(None) - ) - - self.datastore.get_device_updates_by_remote = Mock( # type: ignore[assignment] - return_value=make_awaitable((0, [])) + self.datastore.get_device_updates_by_remote = AsyncMock( # type: ignore[method-assign] + return_value=(0, []) ) - self.datastore.get_destination_last_successful_stream_ordering = Mock( # type: ignore[assignment] - return_value=make_awaitable(None) + self.datastore.get_destination_last_successful_stream_ordering = AsyncMock( # type: ignore[method-assign] + return_value=None ) - self.datastore.get_received_txn_response = Mock( # type: ignore[assignment] - return_value=make_awaitable(None) + self.datastore.get_received_txn_response = AsyncMock( # type: ignore[method-assign] + return_value=None ) self.room_members: List[UserID] = [] @@ -136,25 +141,25 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): raise AuthError(401, "User is not in the room") return None - hs.get_auth().check_user_in_room = Mock( # type: ignore[assignment] + hs.get_auth().check_user_in_room = Mock( # type: ignore[method-assign] side_effect=check_user_in_room ) async def check_host_in_room(room_id: str, server_name: str) -> bool: return room_id == ROOM_ID - hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[assignment] + hs.get_event_auth_handler().is_host_in_room = Mock( # type: ignore[method-assign] side_effect=check_host_in_room ) async def get_current_hosts_in_room(room_id: str) -> Set[str]: return {member.domain for member in self.room_members} - hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment] + hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[method-assign] side_effect=get_current_hosts_in_room ) - hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[assignment] + hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = Mock( # type: ignore[method-assign] side_effect=get_current_hosts_in_room ) @@ -163,27 +168,25 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_users_in_room = Mock(side_effect=get_users_in_room) - self.datastore.get_user_directory_stream_pos = Mock( # type: ignore[assignment] - side_effect=( - # we deliberately return a non-None stream pos to avoid - # doing an initial_sync - lambda: make_awaitable(1) - ) + self.datastore.get_user_directory_stream_pos = AsyncMock( # type: ignore[method-assign] + # we deliberately return a non-None stream pos to avoid + # doing an initial_sync + return_value=1 ) - self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[assignment] + self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) # type: ignore[method-assign] - self.datastore.get_to_device_stream_token = Mock( # type: ignore[assignment] - side_effect=lambda: 0 + self.datastore.get_to_device_stream_token = Mock( # type: ignore[method-assign] + return_value=0 ) - self.datastore.get_new_device_msgs_for_remote = Mock( # type: ignore[assignment] - side_effect=lambda *args, **kargs: make_awaitable(([], 0)) + self.datastore.get_new_device_msgs_for_remote = AsyncMock( # type: ignore[method-assign] + return_value=([], 0) ) - self.datastore.delete_device_msgs_for_remote = Mock( # type: ignore[assignment] - side_effect=lambda *args, **kargs: make_awaitable(None) + self.datastore.delete_device_msgs_for_remote = AsyncMock( # type: ignore[method-assign] + return_value=None ) - self.datastore.set_received_txn_response = Mock( # type: ignore[assignment] - side_effect=lambda *args, **kwargs: make_awaitable(None) + self.datastore.set_received_txn_response = AsyncMock( # type: ignore[method-assign] + return_value=None ) def test_started_typing_local(self) -> None: @@ -200,7 +203,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls( + [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] + ) self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( @@ -246,8 +251,8 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ), json_data_callback=ANY, long_retries=True, - backoff_on_404=True, try_trailing_slash_on_400=True, + backoff_on_all_error_codes=True, ) def test_started_typing_remote_recv(self) -> None: @@ -270,7 +275,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 200) - self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls( + [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] + ) self.assertEqual(self.event_source.get_current_key(), 1) events = self.get_success( @@ -346,7 +353,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls( + [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] + ) self.mock_federation_client.put_json.assert_called_once_with( "farm", @@ -361,7 +370,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ), json_data_callback=ANY, long_retries=True, - backoff_on_404=True, + backoff_on_all_error_codes=True, try_trailing_slash_on_400=True, ) @@ -396,7 +405,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls( + [call(StreamKeyType.TYPING, 1, rooms=[ROOM_ID])] + ) self.on_new_event.reset_mock() self.assertEqual(self.event_source.get_current_key(), 1) @@ -422,7 +433,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.reactor.pump([16]) - self.on_new_event.assert_has_calls([call("typing_key", 2, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls( + [call(StreamKeyType.TYPING, 2, rooms=[ROOM_ID])] + ) self.assertEqual(self.event_source.get_current_key(), 2) events = self.get_success( @@ -456,7 +469,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.on_new_event.assert_has_calls([call("typing_key", 3, rooms=[ROOM_ID])]) + self.on_new_event.assert_has_calls( + [call(StreamKeyType.TYPING, 3, rooms=[ROOM_ID])] + ) self.on_new_event.reset_mock() self.assertEqual(self.event_source.get_current_key(), 3) diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 15a7dc6818..b5f15aa7d4 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Tuple -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch from urllib.parse import quote from twisted.test.proto_helpers import MemoryReactor @@ -30,7 +30,7 @@ from synapse.util import Clock from tests import unittest from tests.storage.test_user_directory import GetUserDirectoryTables -from tests.test_utils import event_injection, make_awaitable +from tests.test_utils import event_injection from tests.test_utils.event_injection import inject_member_event from tests.unittest import override_config @@ -356,7 +356,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): support_user_id, ProfileInfo("I love support me", None) ) ) - profile = self.get_success(self.store.get_user_in_directory(support_user_id)) + profile = self.get_success(self.store._get_user_in_directory(support_user_id)) self.assertIsNone(profile) display_name = "display_name" @@ -364,7 +364,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.get_success( self.handler.handle_local_profile_change(regular_user_id, profile_info) ) - profile = self.get_success(self.store.get_user_in_directory(regular_user_id)) + profile = self.get_success(self.store._get_user_in_directory(regular_user_id)) assert profile is not None self.assertTrue(profile["display_name"] == display_name) @@ -383,7 +383,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) # profile is in directory - profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + profile = self.get_success(self.store._get_user_in_directory(r_user_id)) assert profile is not None self.assertTrue(profile["display_name"] == display_name) @@ -392,7 +392,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.get_success(self.handler.handle_local_user_deactivated(r_user_id)) # profile is not in directory - profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + profile = self.get_success(self.store._get_user_in_directory(r_user_id)) self.assertIsNone(profile) # update profile after deactivation @@ -401,7 +401,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) # profile is furthermore not in directory - profile = self.get_success(self.store.get_user_in_directory(r_user_id)) + profile = self.get_success(self.store._get_user_in_directory(r_user_id)) self.assertIsNone(profile) def test_handle_local_profile_change_with_appservice_user(self) -> None: @@ -411,7 +411,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) # profile is not in directory - profile = self.get_success(self.store.get_user_in_directory(as_user_id)) + profile = self.get_success(self.store._get_user_in_directory(as_user_id)) self.assertIsNone(profile) # update profile @@ -421,13 +421,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) # profile is still not in directory - profile = self.get_success(self.store.get_user_in_directory(as_user_id)) + profile = self.get_success(self.store._get_user_in_directory(as_user_id)) self.assertIsNone(profile) def test_handle_local_profile_change_with_appservice_sender(self) -> None: # profile is not in directory profile = self.get_success( - self.store.get_user_in_directory(self.appservice.sender) + self.store._get_user_in_directory(self.appservice.sender) ) self.assertIsNone(profile) @@ -441,11 +441,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # profile is still not in directory profile = self.get_success( - self.store.get_user_in_directory(self.appservice.sender) + self.store._get_user_in_directory(self.appservice.sender) ) self.assertIsNone(profile) def test_handle_user_deactivated_support_user(self) -> None: + """Ensure a support user doesn't get added to the user directory after deactivation.""" s_user_id = "@support:test" self.get_success( self.store.register_user( @@ -453,14 +454,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) - mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) - with patch.object( - self.store, "remove_from_user_dir", mock_remove_from_user_dir - ): - self.get_success(self.handler.handle_local_user_deactivated(s_user_id)) - # BUG: the correct spelling is assert_not_called, but that makes the test fail - # and it's not clear that this is actually the behaviour we want. - mock_remove_from_user_dir.not_called() + # The profile should not be in the directory. + profile = self.get_success(self.store._get_user_in_directory(s_user_id)) + self.assertIsNone(profile) + + # Remove the user from the directory. + self.get_success(self.handler.handle_local_user_deactivated(s_user_id)) + + # The profile should still not be in the user directory. + profile = self.get_success(self.store._get_user_in_directory(s_user_id)) + self.assertIsNone(profile) def test_handle_user_deactivated_regular_user(self) -> None: r_user_id = "@regular:test" @@ -468,7 +471,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.store.register_user(user_id=r_user_id, password_hash=None) ) - mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) + mock_remove_from_user_dir = AsyncMock(return_value=None) with patch.object( self.store, "remove_from_user_dir", mock_remove_from_user_dir ): diff --git a/tests/handlers/test_worker_lock.py b/tests/handlers/test_worker_lock.py new file mode 100644 index 0000000000..73e548726c --- /dev/null +++ b/tests/handlers/test_worker_lock.py @@ -0,0 +1,74 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.replication._base import BaseMultiWorkerStreamTestCase + + +class WorkerLockTestCase(unittest.HomeserverTestCase): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.worker_lock_handler = self.hs.get_worker_locks_handler() + + def test_wait_for_lock_locally(self) -> None: + """Test waiting for a lock on a single worker""" + + lock1 = self.worker_lock_handler.acquire_lock("name", "key") + self.get_success(lock1.__aenter__()) + + lock2 = self.worker_lock_handler.acquire_lock("name", "key") + d2 = defer.ensureDeferred(lock2.__aenter__()) + self.assertNoResult(d2) + + self.get_success(lock1.__aexit__(None, None, None)) + + self.get_success(d2) + self.get_success(lock2.__aexit__(None, None, None)) + + +class WorkerLockWorkersTestCase(BaseMultiWorkerStreamTestCase): + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self.main_worker_lock_handler = self.hs.get_worker_locks_handler() + + def test_wait_for_lock_worker(self) -> None: + """Test waiting for a lock on another worker""" + + worker = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={ + "redis": {"enabled": True}, + }, + ) + worker_lock_handler = worker.get_worker_locks_handler() + + lock1 = self.main_worker_lock_handler.acquire_lock("name", "key") + self.get_success(lock1.__aenter__()) + + lock2 = worker_lock_handler.acquire_lock("name", "key") + d2 = defer.ensureDeferred(lock2.__aenter__()) + self.assertNoResult(d2) + + self.get_success(lock1.__aexit__(None, None, None)) + + self.get_success(d2) + self.get_success(lock2.__aexit__(None, None, None)) diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 105b4caefa..9f63fa6fa8 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -14,8 +14,8 @@ import base64 import logging import os -from typing import Any, Awaitable, Callable, Generator, List, Optional, cast -from unittest.mock import Mock, patch +from typing import Generator, List, Optional, cast +from unittest.mock import AsyncMock, call, patch import treq from netaddr import IPSet @@ -41,7 +41,7 @@ from twisted.web.iweb import IPolicyForHTTPS, IResponse from synapse.config.homeserver import HomeServerConfig from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent -from synapse.http.federation.srv_resolver import Server +from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import ( WELL_KNOWN_MAX_SIZE, WellKnownResolver, @@ -68,21 +68,11 @@ from tests.utils import checked_cast, default_config logger = logging.getLogger(__name__) -# Once Async Mocks or lambdas are supported this can go away. -def generate_resolve_service( - result: List[Server], -) -> Callable[[Any], Awaitable[List[Server]]]: - async def resolve_service(_: Any) -> List[Server]: - return result - - return resolve_service - - class MatrixFederationAgentTests(unittest.TestCase): def setUp(self) -> None: self.reactor = ThreadedMemoryReactorClock() - self.mock_resolver = Mock() + self.mock_resolver = AsyncMock(spec=SrvResolver) config_dict = default_config("test", parse=False) config_dict["federation_custom_ca_list"] = [get_test_ca_cert_file()] @@ -292,7 +282,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.agent = self._make_agent() self.reactor.lookups["testserv"] = "1.2.3.4" - test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar") + test_d = self._make_get_request(b"matrix-federation://testserv:8448/foo/bar") # Nothing happened yet self.assertNoResult(test_d) @@ -393,7 +383,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["proxy.com"] = "9.9.9.9" - test_d = self._make_get_request(b"matrix://testserv:8448/foo/bar") + test_d = self._make_get_request(b"matrix-federation://testserv:8448/foo/bar") # Nothing happened yet self.assertNoResult(test_d) @@ -514,7 +504,7 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual(response.code, 200) # Send the body - request.write('{ "a": 1 }'.encode("ascii")) + request.write(b'{ "a": 1 }') request.finish() self.reactor.pump((0.1,)) @@ -532,7 +522,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # there will be a getaddrinfo on the IP self.reactor.lookups["1.2.3.4"] = "1.2.3.4" - test_d = self._make_get_request(b"matrix://1.2.3.4/foo/bar") + test_d = self._make_get_request(b"matrix-federation://1.2.3.4/foo/bar") # Nothing happened yet self.assertNoResult(test_d) @@ -568,7 +558,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # there will be a getaddrinfo on the IP self.reactor.lookups["::1"] = "::1" - test_d = self._make_get_request(b"matrix://[::1]/foo/bar") + test_d = self._make_get_request(b"matrix-federation://[::1]/foo/bar") # Nothing happened yet self.assertNoResult(test_d) @@ -604,7 +594,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # there will be a getaddrinfo on the IP self.reactor.lookups["::1"] = "::1" - test_d = self._make_get_request(b"matrix://[::1]:80/foo/bar") + test_d = self._make_get_request(b"matrix-federation://[::1]:80/foo/bar") # Nothing happened yet self.assertNoResult(test_d) @@ -636,10 +626,10 @@ class MatrixFederationAgentTests(unittest.TestCase): """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] self.reactor.lookups["testserv1"] = "1.2.3.4" - test_d = self._make_get_request(b"matrix://testserv1/foo/bar") + test_d = self._make_get_request(b"matrix-federation://testserv1/foo/bar") # Nothing happened yet self.assertNoResult(test_d) @@ -661,9 +651,9 @@ class MatrixFederationAgentTests(unittest.TestCase): # .well-known request fails. self.reactor.pump((0.4,)) - # now there should be a SRV lookup - self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv1" + # now there should be two SRV lookups + self.mock_resolver.resolve_service.assert_has_calls( + [call(b"_matrix-fed._tcp.testserv1"), call(b"_matrix._tcp.testserv1")] ) # we should fall back to a direct connection @@ -693,7 +683,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # there will be a getaddrinfo on the IP self.reactor.lookups["1.2.3.5"] = "1.2.3.5" - test_d = self._make_get_request(b"matrix://1.2.3.5/foo/bar") + test_d = self._make_get_request(b"matrix-federation://1.2.3.5/foo/bar") # Nothing happened yet self.assertNoResult(test_d) @@ -722,10 +712,10 @@ class MatrixFederationAgentTests(unittest.TestCase): """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] self.reactor.lookups["testserv"] = "1.2.3.4" - test_d = self._make_get_request(b"matrix://testserv/foo/bar") + test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) @@ -747,9 +737,9 @@ class MatrixFederationAgentTests(unittest.TestCase): # .well-known request fails. self.reactor.pump((0.4,)) - # now there should be a SRV lookup - self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv" + # now there should be two SRV lookups + self.mock_resolver.resolve_service.assert_has_calls( + [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")] ) # we should fall back to a direct connection @@ -776,11 +766,11 @@ class MatrixFederationAgentTests(unittest.TestCase): """Test the behaviour when the .well-known delegates elsewhere""" self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" - test_d = self._make_get_request(b"matrix://testserv/foo/bar") + test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) @@ -798,9 +788,12 @@ class MatrixFederationAgentTests(unittest.TestCase): content=b'{ "m.server": "target-server" }', ) - # there should be a SRV lookup - self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server" + # there should be two SRV lookups + self.mock_resolver.resolve_service.assert_has_calls( + [ + call(b"_matrix-fed._tcp.target-server"), + call(b"_matrix._tcp.target-server"), + ] ) # now we should get a connection to the target server @@ -840,11 +833,11 @@ class MatrixFederationAgentTests(unittest.TestCase): """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["target-server"] = "1::f" - test_d = self._make_get_request(b"matrix://testserv/foo/bar") + test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) @@ -888,9 +881,12 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) - # there should be a SRV lookup - self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server" + # there should be two SRV lookups + self.mock_resolver.resolve_service.assert_has_calls( + [ + call(b"_matrix-fed._tcp.target-server"), + call(b"_matrix._tcp.target-server"), + ] ) # now we should get a connection to the target server @@ -930,10 +926,10 @@ class MatrixFederationAgentTests(unittest.TestCase): """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] self.reactor.lookups["testserv"] = "1.2.3.4" - test_d = self._make_get_request(b"matrix://testserv/foo/bar") + test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) @@ -952,9 +948,9 @@ class MatrixFederationAgentTests(unittest.TestCase): client_factory, expected_sni=b"testserv", content=b"NOT JSON" ) - # now there should be a SRV lookup - self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv" + # now there should be two SRV lookups + self.mock_resolver.resolve_service.assert_has_calls( + [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")] ) # we should fall back to a direct connection @@ -986,7 +982,7 @@ class MatrixFederationAgentTests(unittest.TestCase): # the config left to the default, which will not trust it (since the # presented cert is signed by a test CA) - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] self.reactor.lookups["testserv"] = "1.2.3.4" config = default_config("test", parse=True) @@ -1009,7 +1005,7 @@ class MatrixFederationAgentTests(unittest.TestCase): ), ) - test_d = agent.request(b"GET", b"matrix://testserv/foo/bar") + test_d = agent.request(b"GET", b"matrix-federation://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) @@ -1026,30 +1022,74 @@ class MatrixFederationAgentTests(unittest.TestCase): # there should be no requests self.assertEqual(len(http_proto.requests), 0) - # and there should be a SRV lookup instead - self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv" + # and there should be two SRV lookups instead + self.mock_resolver.resolve_service.assert_has_calls( + [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")] ) def test_get_hostname_srv(self) -> None: """ - Test the behaviour when there is a single SRV record + Test the behaviour when there is a single SRV record for _matrix-fed. """ self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service( - [Server(host=b"srvtarget", port=8443)] - ) + self.mock_resolver.resolve_service.return_value = [ + Server(host=b"srvtarget", port=8443) + ] self.reactor.lookups["srvtarget"] = "1.2.3.4" - test_d = self._make_get_request(b"matrix://testserv/foo/bar") + test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) # the request for a .well-known will have failed with a DNS lookup error. self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv" + b"_matrix-fed._tcp.testserv" + ) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8443) + + # make a test server, and wire up the client + http_server = self._make_connection(client_factory, expected_sni=b"testserv") + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + + def test_get_hostname_srv_legacy(self) -> None: + """ + Test the behaviour when there is a single SRV record for _matrix. + """ + self.agent = self._make_agent() + + # Return no entries for the _matrix-fed lookup, and a response for _matrix. + self.mock_resolver.resolve_service.side_effect = [ + [], + [Server(host=b"srvtarget", port=8443)], + ] + self.reactor.lookups["srvtarget"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + # the request for a .well-known will have failed with a DNS lookup error. + self.mock_resolver.resolve_service.assert_has_calls( + [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")] ) # Make sure treq is trying to connect @@ -1075,14 +1115,14 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_get_well_known_srv(self) -> None: """Test the behaviour when the .well-known redirects to a place where there - is a SRV. + is a _matrix-fed SRV record. """ self.agent = self._make_agent() self.reactor.lookups["testserv"] = "1.2.3.4" self.reactor.lookups["srvtarget"] = "5.6.7.8" - test_d = self._make_get_request(b"matrix://testserv/foo/bar") + test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) @@ -1094,9 +1134,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertEqual(host, "1.2.3.4") self.assertEqual(port, 443) - self.mock_resolver.resolve_service.side_effect = generate_resolve_service( - [Server(host=b"srvtarget", port=8443)] - ) + self.mock_resolver.resolve_service.return_value = [ + Server(host=b"srvtarget", port=8443) + ] self._handle_well_known_connection( client_factory, @@ -1106,7 +1146,72 @@ class MatrixFederationAgentTests(unittest.TestCase): # there should be a SRV lookup self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.target-server" + b"_matrix-fed._tcp.target-server" + ) + + # now we should get a connection to the target of the SRV record + self.assertEqual(len(clients), 2) + (host, port, client_factory, _timeout, _bindAddress) = clients[1] + self.assertEqual(host, "5.6.7.8") + self.assertEqual(port, 8443) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, expected_sni=b"target-server" + ) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual( + request.requestHeaders.getRawHeaders(b"host"), [b"target-server"] + ) + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + + def test_get_well_known_srv_legacy(self) -> None: + """Test the behaviour when the .well-known redirects to a place where there + is a _matrix SRV record. + """ + self.agent = self._make_agent() + + self.reactor.lookups["testserv"] = "1.2.3.4" + self.reactor.lookups["srvtarget"] = "5.6.7.8" + + test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + # there should be an attempt to connect on port 443 for the .well-known + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 443) + + # Return no entries for the _matrix-fed lookup, and a response for _matrix. + self.mock_resolver.resolve_service.side_effect = [ + [], + [Server(host=b"srvtarget", port=8443)], + ] + + self._handle_well_known_connection( + client_factory, + expected_sni=b"testserv", + content=b'{ "m.server": "target-server" }', + ) + + # there should be two SRV lookups + self.mock_resolver.resolve_service.assert_has_calls( + [ + call(b"_matrix-fed._tcp.target-server"), + call(b"_matrix._tcp.target-server"), + ] ) # now we should get a connection to the target of the SRV record @@ -1137,13 +1242,15 @@ class MatrixFederationAgentTests(unittest.TestCase): """test the behaviour when the server name has idna chars in""" self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service([]) + self.mock_resolver.resolve_service.return_value = [] # the resolver is always called with the IDNA hostname as a native string. self.reactor.lookups["xn--bcher-kva.com"] = "1.2.3.4" # this is idna for bücher.com - test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar") + test_d = self._make_get_request( + b"matrix-federation://xn--bcher-kva.com/foo/bar" + ) # Nothing happened yet self.assertNoResult(test_d) @@ -1166,8 +1273,11 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.4,)) # now there should have been a SRV lookup - self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.xn--bcher-kva.com" + self.mock_resolver.resolve_service.assert_has_calls( + [ + call(b"_matrix-fed._tcp.xn--bcher-kva.com"), + call(b"_matrix._tcp.xn--bcher-kva.com"), + ] ) # We should fall back to port 8448 @@ -1196,21 +1306,73 @@ class MatrixFederationAgentTests(unittest.TestCase): self.successResultOf(test_d) def test_idna_srv_target(self) -> None: - """test the behaviour when the target of a SRV record has idna chars""" + """test the behaviour when the target of a _matrix-fed SRV record has idna chars""" self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service( - [Server(host=b"xn--trget-3qa.com", port=8443)] # târget.com - ) + self.mock_resolver.resolve_service.return_value = [ + Server(host=b"xn--trget-3qa.com", port=8443) + ] # târget.com self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4" - test_d = self._make_get_request(b"matrix://xn--bcher-kva.com/foo/bar") + test_d = self._make_get_request( + b"matrix-federation://xn--bcher-kva.com/foo/bar" + ) # Nothing happened yet self.assertNoResult(test_d) self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.xn--bcher-kva.com" + b"_matrix-fed._tcp.xn--bcher-kva.com" + ) + + # Make sure treq is trying to connect + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients[0] + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8443) + + # make a test server, and wire up the client + http_server = self._make_connection( + client_factory, expected_sni=b"xn--bcher-kva.com" + ) + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual( + request.requestHeaders.getRawHeaders(b"host"), [b"xn--bcher-kva.com"] + ) + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + + def test_idna_srv_target_legacy(self) -> None: + """test the behaviour when the target of a _matrix SRV record has idna chars""" + self.agent = self._make_agent() + + # Return no entries for the _matrix-fed lookup, and a response for _matrix. + self.mock_resolver.resolve_service.side_effect = [ + [], + [Server(host=b"xn--trget-3qa.com", port=8443)], + ] # târget.com + self.reactor.lookups["xn--trget-3qa.com"] = "1.2.3.4" + + test_d = self._make_get_request( + b"matrix-federation://xn--bcher-kva.com/foo/bar" + ) + + # Nothing happened yet + self.assertNoResult(test_d) + + self.mock_resolver.resolve_service.assert_has_calls( + [ + call(b"_matrix-fed._tcp.xn--bcher-kva.com"), + call(b"_matrix._tcp.xn--bcher-kva.com"), + ] ) # Make sure treq is trying to connect @@ -1400,24 +1562,82 @@ class MatrixFederationAgentTests(unittest.TestCase): self.assertIsNone(r.delegated_server) def test_srv_fallbacks(self) -> None: - """Test that other SRV results are tried if the first one fails.""" + """Test that other SRV results are tried if the first one fails for _matrix-fed SRV.""" + self.agent = self._make_agent() + + self.mock_resolver.resolve_service.return_value = [ + Server(host=b"target.com", port=8443), + Server(host=b"target.com", port=8444), + ] + self.reactor.lookups["target.com"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + self.mock_resolver.resolve_service.assert_called_once_with( + b"_matrix-fed._tcp.testserv" + ) + + # We should see an attempt to connect to the first server + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8443) + + # Fonx the connection + client_factory.clientConnectionFailed(None, Exception("nope")) + + # There's a 300ms delay in HostnameEndpoint + self.reactor.pump((0.4,)) + + # Hasn't failed yet + self.assertNoResult(test_d) + + # We shouldnow see an attempt to connect to the second server + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8444) + + # make a test server, and wire up the client + http_server = self._make_connection(client_factory, expected_sni=b"testserv") + + self.assertEqual(len(http_server.requests), 1) + request = http_server.requests[0] + self.assertEqual(request.method, b"GET") + self.assertEqual(request.path, b"/foo/bar") + self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"testserv"]) + + # finish the request + request.finish() + self.reactor.pump((0.1,)) + self.successResultOf(test_d) + + def test_srv_fallbacks_legacy(self) -> None: + """Test that other SRV results are tried if the first one fails for _matrix SRV.""" self.agent = self._make_agent() - self.mock_resolver.resolve_service.side_effect = generate_resolve_service( + # Return no entries for the _matrix-fed lookup, and a response for _matrix. + self.mock_resolver.resolve_service.side_effect = [ + [], [ Server(host=b"target.com", port=8443), Server(host=b"target.com", port=8444), - ] - ) + ], + ] self.reactor.lookups["target.com"] = "1.2.3.4" - test_d = self._make_get_request(b"matrix://testserv/foo/bar") + test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") # Nothing happened yet self.assertNoResult(test_d) - self.mock_resolver.resolve_service.assert_called_once_with( - b"_matrix._tcp.testserv" + self.mock_resolver.resolve_service.assert_has_calls( + [call(b"_matrix-fed._tcp.testserv"), call(b"_matrix._tcp.testserv")] ) # We should see an attempt to connect to the first server @@ -1457,6 +1677,43 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((0.1,)) self.successResultOf(test_d) + def test_srv_no_fallback_to_legacy(self) -> None: + """Test that _matrix SRV results are not tried if the _matrix-fed one fails.""" + self.agent = self._make_agent() + + # Return a failing entry for _matrix-fed. + self.mock_resolver.resolve_service.side_effect = [ + [Server(host=b"target.com", port=8443)], + [], + ] + self.reactor.lookups["target.com"] = "1.2.3.4" + + test_d = self._make_get_request(b"matrix-federation://testserv/foo/bar") + + # Nothing happened yet + self.assertNoResult(test_d) + + # Only the _matrix-fed is checked, _matrix is ignored. + self.mock_resolver.resolve_service.assert_called_once_with( + b"_matrix-fed._tcp.testserv" + ) + + # We should see an attempt to connect to the first server + clients = self.reactor.tcpClients + self.assertEqual(len(clients), 1) + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "1.2.3.4") + self.assertEqual(port, 8443) + + # Fonx the connection + client_factory.clientConnectionFailed(None, Exception("nope")) + + # There's a 300ms delay in HostnameEndpoint + self.reactor.pump((0.4,)) + + # Failed to resolve a server. + self.assertFailure(test_d, Exception) + class TestCachePeriodFromHeaders(unittest.TestCase): def test_cache_control(self) -> None: diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py index 0dfc03ce50..ab94f3f67a 100644 --- a/tests/http/test_matrixfederationclient.py +++ b/tests/http/test_matrixfederationclient.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Generator -from unittest.mock import Mock +from typing import Any, Dict, Generator +from unittest.mock import ANY, Mock, create_autospec from netaddr import IPSet from parameterized import parameterized @@ -21,10 +21,12 @@ from twisted.internet import defer from twisted.internet.defer import Deferred, TimeoutError from twisted.internet.error import ConnectingCancelledError, DNSLookupError from twisted.test.proto_helpers import MemoryReactor, StringTransport -from twisted.web.client import ResponseNeverReceived +from twisted.web.client import Agent, ResponseNeverReceived from twisted.web.http import HTTPChannel +from twisted.web.http_headers import Headers -from synapse.api.errors import RequestSendFailed +from synapse.api.errors import HttpResponseException, RequestSendFailed +from synapse.config._base import ConfigError from synapse.http.matrixfederationclient import ( ByteParser, MatrixFederationHttpClient, @@ -39,8 +41,10 @@ from synapse.logging.context import ( from synapse.server import HomeServer from synapse.util import Clock +from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import FakeTransport -from tests.unittest import HomeserverTestCase +from tests.test_utils import FakeResponse +from tests.unittest import HomeserverTestCase, override_config def check_logcontext(context: LoggingContextOrSentinel) -> None: @@ -640,3 +644,293 @@ class FederationClientTests(HomeserverTestCase): self.cl.build_auth_headers( b"", b"GET", b"https://example.com", destination_is=b"" ) + + @override_config( + { + "federation": { + "client_timeout": "180s", + "max_long_retry_delay": "100s", + "max_short_retry_delay": "7s", + "max_long_retries": 20, + "max_short_retries": 5, + } + } + ) + def test_configurable_retry_and_delay_values(self) -> None: + self.assertEqual(self.cl.default_timeout_seconds, 180) + self.assertEqual(self.cl.max_long_retry_delay_seconds, 100) + self.assertEqual(self.cl.max_short_retry_delay_seconds, 7) + self.assertEqual(self.cl.max_long_retries, 20) + self.assertEqual(self.cl.max_short_retries, 5) + + +class FederationClientProxyTests(BaseMultiWorkerStreamTestCase): + def default_config(self) -> Dict[str, Any]: + conf = super().default_config() + conf["instance_map"] = { + "main": {"host": "testserv", "port": 8765}, + "federation_sender": {"host": "testserv", "port": 1001}, + } + return conf + + @override_config( + { + "outbound_federation_restricted_to": ["federation_sender"], + "worker_replication_secret": "secret", + } + ) + def test_proxy_requests_through_federation_sender_worker(self) -> None: + """ + Test that all outbound federation requests go through the `federation_sender` + worker + """ + # Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance + # so we can act like some remote server responding to requests + mock_client_on_federation_sender = Mock() + mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True) + mock_client_on_federation_sender.agent = mock_agent_on_federation_sender + + # Create the `federation_sender` worker + self.make_worker_hs( + "synapse.app.generic_worker", + {"worker_name": "federation_sender"}, + federation_http_client=mock_client_on_federation_sender, + ) + + # Fake `remoteserv:8008` responding to requests + mock_agent_on_federation_sender.request.side_effect = ( + lambda *args, **kwargs: defer.succeed( + FakeResponse.json( + payload={ + "foo": "bar", + } + ) + ) + ) + + # This federation request from the main process should be proxied through the + # `federation_sender` worker off to the remote server + test_request_from_main_process_d = defer.ensureDeferred( + self.hs.get_federation_http_client().get_json("remoteserv:8008", "foo/bar") + ) + + # Pump the reactor so our deferred goes through the motions + self.pump() + + # Make sure that the request was proxied through the `federation_sender` worker + mock_agent_on_federation_sender.request.assert_called_once_with( + b"GET", + b"matrix-federation://remoteserv:8008/foo/bar", + headers=ANY, + bodyProducer=ANY, + ) + + # Make sure the response is as expected back on the main worker + res = self.successResultOf(test_request_from_main_process_d) + self.assertEqual(res, {"foo": "bar"}) + + @override_config( + { + "outbound_federation_restricted_to": ["federation_sender"], + "worker_replication_secret": "secret", + } + ) + def test_proxy_request_with_network_error_through_federation_sender_worker( + self, + ) -> None: + """ + Test that when the outbound federation request fails with a network related + error, a sensible error makes its way back to the main process. + """ + # Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance + # so we can act like some remote server responding to requests + mock_client_on_federation_sender = Mock() + mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True) + mock_client_on_federation_sender.agent = mock_agent_on_federation_sender + + # Create the `federation_sender` worker + self.make_worker_hs( + "synapse.app.generic_worker", + {"worker_name": "federation_sender"}, + federation_http_client=mock_client_on_federation_sender, + ) + + # Fake `remoteserv:8008` responding to requests + mock_agent_on_federation_sender.request.side_effect = ( + lambda *args, **kwargs: defer.fail(ResponseNeverReceived("fake error")) + ) + + # This federation request from the main process should be proxied through the + # `federation_sender` worker off to the remote server + test_request_from_main_process_d = defer.ensureDeferred( + self.hs.get_federation_http_client().get_json("remoteserv:8008", "foo/bar") + ) + + # Pump the reactor so our deferred goes through the motions. We pump with 10 + # seconds (0.1 * 100) so the `MatrixFederationHttpClient` runs out of retries + # and finally passes along the error response. + self.pump(0.1) + + # Make sure that the request was proxied through the `federation_sender` worker + mock_agent_on_federation_sender.request.assert_called_with( + b"GET", + b"matrix-federation://remoteserv:8008/foo/bar", + headers=ANY, + bodyProducer=ANY, + ) + + # Make sure we get some sort of error back on the main worker + failure_res = self.failureResultOf(test_request_from_main_process_d) + self.assertIsInstance(failure_res.value, RequestSendFailed) + self.assertIsInstance(failure_res.value.inner_exception, HttpResponseException) + self.assertEqual(failure_res.value.inner_exception.code, 502) + + @override_config( + { + "outbound_federation_restricted_to": ["federation_sender"], + "worker_replication_secret": "secret", + } + ) + def test_proxy_requests_and_discards_hop_by_hop_headers(self) -> None: + """ + Test to make sure hop-by-hop headers and addional headers defined in the + `Connection` header are discarded when proxying requests + """ + # Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance + # so we can act like some remote server responding to requests + mock_client_on_federation_sender = Mock() + mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True) + mock_client_on_federation_sender.agent = mock_agent_on_federation_sender + + # Create the `federation_sender` worker + self.make_worker_hs( + "synapse.app.generic_worker", + {"worker_name": "federation_sender"}, + federation_http_client=mock_client_on_federation_sender, + ) + + # Fake `remoteserv:8008` responding to requests + mock_agent_on_federation_sender.request.side_effect = lambda *args, **kwargs: defer.succeed( + FakeResponse( + code=200, + body=b'{"foo": "bar"}', + headers=Headers( + { + "Content-Type": ["application/json"], + "Connection": ["close, X-Foo, X-Bar"], + # Should be removed because it's defined in the `Connection` header + "X-Foo": ["foo"], + "X-Bar": ["bar"], + # Should be removed because it's a hop-by-hop header + "Proxy-Authorization": "abcdef", + } + ), + ) + ) + + # This federation request from the main process should be proxied through the + # `federation_sender` worker off to the remote server + test_request_from_main_process_d = defer.ensureDeferred( + self.hs.get_federation_http_client().get_json_with_headers( + "remoteserv:8008", "foo/bar" + ) + ) + + # Pump the reactor so our deferred goes through the motions + self.pump() + + # Make sure that the request was proxied through the `federation_sender` worker + mock_agent_on_federation_sender.request.assert_called_once_with( + b"GET", + b"matrix-federation://remoteserv:8008/foo/bar", + headers=ANY, + bodyProducer=ANY, + ) + + res, headers = self.successResultOf(test_request_from_main_process_d) + header_names = set(headers.keys()) + + # Make sure the response does not include the hop-by-hop headers + self.assertNotIn(b"X-Foo", header_names) + self.assertNotIn(b"X-Bar", header_names) + self.assertNotIn(b"Proxy-Authorization", header_names) + # Make sure the response is as expected back on the main worker + self.assertEqual(res, {"foo": "bar"}) + + @override_config( + { + "outbound_federation_restricted_to": ["federation_sender"], + # `worker_replication_secret` is set here so that the test setup is able to pass + # but the actual homserver creation test is in the test body below + "worker_replication_secret": "secret", + } + ) + def test_not_able_to_proxy_requests_through_federation_sender_worker_when_no_secret_configured( + self, + ) -> None: + """ + Test that we aren't able to proxy any outbound federation requests when + `worker_replication_secret` is not configured. + """ + with self.assertRaises(ConfigError): + # Create the `federation_sender` worker + self.make_worker_hs( + "synapse.app.generic_worker", + { + "worker_name": "federation_sender", + # Test that we aren't able to proxy any outbound federation requests + # when `worker_replication_secret` is not configured. + "worker_replication_secret": None, + }, + ) + + @override_config( + { + "outbound_federation_restricted_to": ["federation_sender"], + "worker_replication_secret": "secret", + } + ) + def test_not_able_to_proxy_requests_through_federation_sender_worker_when_wrong_auth_given( + self, + ) -> None: + """ + Test that we aren't able to proxy any outbound federation requests when the + wrong authorization is given. + """ + # Mock out the `MatrixFederationHttpClient` of the `federation_sender` instance + # so we can act like some remote server responding to requests + mock_client_on_federation_sender = Mock() + mock_agent_on_federation_sender = create_autospec(Agent, spec_set=True) + mock_client_on_federation_sender.agent = mock_agent_on_federation_sender + + # Create the `federation_sender` worker + self.make_worker_hs( + "synapse.app.generic_worker", + { + "worker_name": "federation_sender", + # Test that we aren't able to proxy any outbound federation requests + # when `worker_replication_secret` is wrong. + "worker_replication_secret": "wrong", + }, + federation_http_client=mock_client_on_federation_sender, + ) + + # This federation request from the main process should be proxied through the + # `federation_sender` worker off but will fail here because it's using the wrong + # authorization. + test_request_from_main_process_d = defer.ensureDeferred( + self.hs.get_federation_http_client().get_json("remoteserv:8008", "foo/bar") + ) + + # Pump the reactor so our deferred goes through the motions. We pump with 10 + # seconds (0.1 * 100) so the `MatrixFederationHttpClient` runs out of retries + # and finally passes along the error response. + self.pump(0.1) + + # Make sure that the request was *NOT* proxied through the `federation_sender` + # worker + mock_agent_on_federation_sender.request.assert_not_called() + + failure_res = self.failureResultOf(test_request_from_main_process_d) + self.assertIsInstance(failure_res.value, HttpResponseException) + self.assertEqual(failure_res.value.code, 401) diff --git a/tests/http/test_proxy.py b/tests/http/test_proxy.py new file mode 100644 index 0000000000..0dc9ba8e05 --- /dev/null +++ b/tests/http/test_proxy.py @@ -0,0 +1,53 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Set + +from parameterized import parameterized + +from synapse.http.proxy import parse_connection_header_value + +from tests.unittest import TestCase + + +class ProxyTests(TestCase): + @parameterized.expand( + [ + [b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}], + # No whitespace + [b"close,X-Foo,X-Bar", {"Close", "X-Foo", "X-Bar"}], + # More whitespace + [b"close, X-Foo, X-Bar", {"Close", "X-Foo", "X-Bar"}], + # "close" directive in not the first position + [b"X-Foo, X-Bar, close", {"X-Foo", "X-Bar", "Close"}], + # Normalizes header capitalization + [b"keep-alive, x-fOo, x-bAr", {"Keep-Alive", "X-Foo", "X-Bar"}], + # Handles header names with whitespace + [ + b"keep-alive, x foo, x bar", + {"Keep-Alive", "X foo", "X bar"}, + ], + ] + ) + def test_parse_connection_header_value( + self, + connection_header_value: bytes, + expected_extra_headers_to_remove: Set[str], + ) -> None: + """ + Tests that the connection header value is parsed correctly + """ + self.assertEqual( + expected_extra_headers_to_remove, + parse_connection_header_value(connection_header_value), + ) diff --git a/tests/http/test_proxyagent.py b/tests/http/test_proxyagent.py index e0ae5a88ff..8164b0b78e 100644 --- a/tests/http/test_proxyagent.py +++ b/tests/http/test_proxyagent.py @@ -33,7 +33,7 @@ from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel from synapse.http.client import BlocklistingReactorWrapper -from synapse.http.connectproxyclient import ProxyCredentials +from synapse.http.connectproxyclient import BasicProxyCredentials from synapse.http.proxyagent import ProxyAgent, parse_proxy from tests.http import ( @@ -205,7 +205,7 @@ class ProxyParserTests(TestCase): """ proxy_cred = None if expected_credentials: - proxy_cred = ProxyCredentials(expected_credentials) + proxy_cred = BasicProxyCredentials(expected_credentials) self.assertEqual( ( expected_scheme, diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py index e28ba84cc2..1bc7d64ad9 100644 --- a/tests/logging/test_opentracing.py +++ b/tests/logging/test_opentracing.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import cast +from typing import Awaitable, cast from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactorClock @@ -227,8 +227,6 @@ class LogContextScopeManagerTestCase(TestCase): Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args` with functions that return deferreds """ - reactor = MemoryReactorClock() - with LoggingContext("root context"): @trace_with_opname("fixture_deferred_func", tracer=self._tracer) @@ -240,9 +238,6 @@ class LogContextScopeManagerTestCase(TestCase): result_d1 = fixture_deferred_func() - # let the tasks complete - reactor.pump((2,) * 8) - self.assertEqual(self.successResultOf(result_d1), "foo") # the span should have been reported @@ -256,8 +251,6 @@ class LogContextScopeManagerTestCase(TestCase): Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args` with async functions """ - reactor = MemoryReactorClock() - with LoggingContext("root context"): @trace_with_opname("fixture_async_func", tracer=self._tracer) @@ -267,9 +260,6 @@ class LogContextScopeManagerTestCase(TestCase): d1 = defer.ensureDeferred(fixture_async_func()) - # let the tasks complete - reactor.pump((2,) * 8) - self.assertEqual(self.successResultOf(d1), "foo") # the span should have been reported @@ -277,3 +267,34 @@ class LogContextScopeManagerTestCase(TestCase): [span.operation_name for span in self._reporter.get_spans()], ["fixture_async_func"], ) + + def test_trace_decorator_awaitable_return(self) -> None: + """ + Test whether we can use `@trace_with_opname` (`@trace`) and `@tag_args` + with functions that return an awaitable (e.g. a coroutine) + """ + with LoggingContext("root context"): + # Something we can return without `await` to get a coroutine + async def fixture_async_func() -> str: + return "foo" + + # The actual kind of function we want to test that returns an awaitable + @trace_with_opname("fixture_awaitable_return_func", tracer=self._tracer) + @tag_args + def fixture_awaitable_return_func() -> Awaitable[str]: + return fixture_async_func() + + # Something we can run with `defer.ensureDeferred(runner())` and pump the + # whole async tasks through to completion. + async def runner() -> str: + return await fixture_awaitable_return_func() + + d1 = defer.ensureDeferred(runner()) + + self.assertEqual(self.successResultOf(d1), "foo") + + # the span should have been reported + self.assertEqual( + [span.operation_name for span in self._reporter.get_spans()], + ["fixture_awaitable_return_func"], + ) diff --git a/tests/logging/test_remote_handler.py b/tests/logging/test_remote_handler.py index 5191e31a8a..45eac100bf 100644 --- a/tests/logging/test_remote_handler.py +++ b/tests/logging/test_remote_handler.py @@ -78,11 +78,11 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): logger = self.get_logger(handler) # Send some debug messages - for i in range(0, 3): + for i in range(3): logger.debug("debug %s" % (i,)) # Send a bunch of useful messages - for i in range(0, 7): + for i in range(7): logger.info("info %s" % (i,)) # The last debug message pushes it past the maximum buffer @@ -108,15 +108,15 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): logger = self.get_logger(handler) # Send some debug messages - for i in range(0, 3): + for i in range(3): logger.debug("debug %s" % (i,)) # Send a bunch of useful messages - for i in range(0, 10): + for i in range(10): logger.warning("warn %s" % (i,)) # Send a bunch of info messages - for i in range(0, 3): + for i in range(3): logger.info("info %s" % (i,)) # The last debug message pushes it past the maximum buffer @@ -144,7 +144,7 @@ class RemoteHandlerTestCase(LoggerCleanupMixin, TestCase): logger = self.get_logger(handler) # Send a bunch of useful messages - for i in range(0, 20): + for i in range(20): logger.warning("warn %s" % (i,)) # Allow the reconnection diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index fa27f1279a..c379853e20 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -164,7 +164,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): # Call requestReceived to finish instantiating the object. request.content = BytesIO() # Partially skip some internal processing of SynapseRequest. - request._started_processing = Mock() # type: ignore[assignment] + request._started_processing = Mock() # type: ignore[method-assign] request.request_metrics = Mock(spec=["name"]) with patch.object(Request, "render"): request.requestReceived(b"POST", b"/_matrix/client/versions", b"1.1") diff --git a/tests/media/test_base.py b/tests/media/test_base.py index 66498c744d..119d7ba66f 100644 --- a/tests/media/test_base.py +++ b/tests/media/test_base.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.media._base import get_filename_from_headers +from unittest.mock import Mock + +from synapse.media._base import add_file_headers, get_filename_from_headers from tests import unittest @@ -20,12 +22,12 @@ from tests import unittest class GetFileNameFromHeadersTests(unittest.TestCase): # input -> expected result TEST_CASES = { - b"inline; filename=abc.txt": "abc.txt", - b'inline; filename="azerty"': "azerty", - b'inline; filename="aze%20rty"': "aze%20rty", - b'inline; filename="aze"rty"': 'aze"rty', - b'inline; filename="azer;ty"': "azer;ty", - b"inline; filename*=utf-8''foo%C2%A3bar": "foo£bar", + b"attachment; filename=abc.txt": "abc.txt", + b'attachment; filename="azerty"': "azerty", + b'attachment; filename="aze%20rty"': "aze%20rty", + b'attachment; filename="aze"rty"': 'aze"rty', + b'attachment; filename="azer;ty"': "azer;ty", + b"attachment; filename*=utf-8''foo%C2%A3bar": "foo£bar", } def tests(self) -> None: @@ -36,3 +38,28 @@ class GetFileNameFromHeadersTests(unittest.TestCase): expected, f"expected output for {hdr!r} to be {expected} but was {res}", ) + + +class AddFileHeadersTests(unittest.TestCase): + TEST_CASES = { + "text/plain": b"inline; filename=file.name", + "text/csv": b"inline; filename=file.name", + "image/png": b"inline; filename=file.name", + "text/html": b"attachment; filename=file.name", + "any/thing": b"attachment; filename=file.name", + } + + def test_content_disposition(self) -> None: + for media_type, expected in self.TEST_CASES.items(): + request = Mock() + add_file_headers(request, media_type, 0, "file.name") + request.setHeader.assert_any_call(b"Content-Disposition", expected) + + def test_no_filename(self) -> None: + request = Mock() + add_file_headers(request, "text/plain", 0, None) + request.setHeader.assert_any_call(b"Content-Disposition", b"inline") + + request.reset_mock() + add_file_headers(request, "text/html", 0, None) + request.setHeader.assert_any_call(b"Content-Disposition", b"attachment") diff --git a/tests/media/test_html_preview.py b/tests/media/test_html_preview.py index e7da75db3e..ea84bb3d3d 100644 --- a/tests/media/test_html_preview.py +++ b/tests/media/test_html_preview.py @@ -24,7 +24,7 @@ from tests import unittest try: import lxml except ImportError: - lxml = None + lxml = None # type: ignore[assignment] class SummarizeTestCase(unittest.TestCase): @@ -160,6 +160,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -176,6 +177,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -195,6 +197,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual( @@ -217,6 +220,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -231,6 +235,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -246,6 +251,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Title", "og:description": "Title"}) @@ -261,6 +267,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) @@ -281,6 +288,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"}) @@ -296,6 +304,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -324,6 +333,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): <head><title>Foo</title></head><body>Some text.</body></html> """.strip() tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -338,6 +348,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -353,6 +364,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) @@ -367,6 +379,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) @@ -380,6 +393,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual( og, @@ -401,6 +415,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual( og, @@ -419,6 +434,7 @@ class OpenGraphFromHtmlTestCase(unittest.TestCase): with a cheeky SVG</svg></u> and <strong>some</strong> tail text</b></a> """ tree = decode_body(html, "http://example.com/test.html") + assert tree is not None og = parse_html_to_open_graph(tree) self.assertEqual( og, diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index f0f2da65db..15f5d644e4 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -28,12 +28,13 @@ from typing_extensions import Literal from twisted.internet import defer from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource from synapse.api.errors import Codes from synapse.events import EventBase from synapse.http.types import QueryParams from synapse.logging.context import make_deferred_yieldable -from synapse.media._base import FileInfo +from synapse.media._base import FileInfo, ThumbnailInfo from synapse.media.filepath import MediaFilePaths from synapse.media.media_storage import MediaStorage, ReadableFileWrapper from synapse.media.storage_provider import FileStorageProviderBackend @@ -41,12 +42,13 @@ from synapse.module_api import ModuleApi from synapse.module_api.callbacks.spamchecker_callbacks import load_legacy_spam_checkers from synapse.rest import admin from synapse.rest.client import login +from synapse.rest.media.thumbnail_resource import ThumbnailResource from synapse.server import HomeServer from synapse.types import JsonDict, RoomAlias from synapse.util import Clock from tests import unittest -from tests.server import FakeChannel, FakeSite, make_request +from tests.server import FakeChannel from tests.test_utils import SMALL_PNG from tests.utils import default_config @@ -129,6 +131,8 @@ class _TestImage: a 404/400 is expected. unable_to_thumbnail: True if we expect the thumbnailing to fail (400), or False if the thumbnailing should succeed or a normal 404 is expected. + is_inline: True if we expect the file to be served using an inline + Content-Disposition or False if we expect an attachment. """ data: bytes @@ -138,6 +142,7 @@ class _TestImage: expected_scaled: Optional[bytes] = None expected_found: bool = True unable_to_thumbnail: bool = False + is_inline: bool = True @parameterized_class( @@ -198,6 +203,25 @@ class _TestImage: unable_to_thumbnail=True, ), ), + # An SVG. + ( + _TestImage( + b"""<?xml version="1.0"?> +<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN" + "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd"> + +<svg xmlns="http://www.w3.org/2000/svg" + width="400" height="400"> + <circle cx="100" cy="100" r="50" stroke="black" + stroke-width="5" fill="red" /> +</svg>""", + b"image/svg", + b".svg", + expected_found=False, + unable_to_thumbnail=True, + is_inline=False, + ), + ), ], ) class MediaRepoTests(unittest.HomeserverTestCase): @@ -266,22 +290,22 @@ class MediaRepoTests(unittest.HomeserverTestCase): return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - media_resource = hs.get_media_repository_resource() - self.download_resource = media_resource.children[b"download"] - self.thumbnail_resource = media_resource.children[b"thumbnail"] self.store = hs.get_datastores().main self.media_repo = hs.get_media_repository() self.media_id = "example.com/12345" + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + def _req( self, content_disposition: Optional[bytes], include_content_type: bool = True ) -> FakeChannel: - channel = make_request( - self.reactor, - FakeSite(self.download_resource, self.reactor), + channel = self.make_request( "GET", - self.media_id, + f"/_matrix/media/v3/download/{self.media_id}", shorthand=False, await_result=False, ) @@ -317,7 +341,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): def test_handle_missing_content_type(self) -> None: channel = self._req( - b"inline; filename=out" + self.test_image.extension, + b"attachment; filename=out" + self.test_image.extension, include_content_type=False, ) headers = channel.headers @@ -331,7 +355,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): If the filename is filename=<ascii> then Synapse will decode it as an ASCII string, and use filename= in the response. """ - channel = self._req(b"inline; filename=out" + self.test_image.extension) + channel = self._req(b"attachment; filename=out" + self.test_image.extension) headers = channel.headers self.assertEqual( @@ -339,7 +363,11 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) self.assertEqual( headers.getRawHeaders(b"Content-Disposition"), - [b"inline; filename=out" + self.test_image.extension], + [ + (b"inline" if self.test_image.is_inline else b"attachment") + + b"; filename=out" + + self.test_image.extension + ], ) def test_disposition_filenamestar_utf8escaped(self) -> None: @@ -350,7 +378,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): """ filename = parse.quote("\u2603".encode()).encode("ascii") channel = self._req( - b"inline; filename*=utf-8''" + filename + self.test_image.extension + b"attachment; filename*=utf-8''" + filename + self.test_image.extension ) headers = channel.headers @@ -359,13 +387,18 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) self.assertEqual( headers.getRawHeaders(b"Content-Disposition"), - [b"inline; filename*=utf-8''" + filename + self.test_image.extension], + [ + (b"inline" if self.test_image.is_inline else b"attachment") + + b"; filename*=utf-8''" + + filename + + self.test_image.extension + ], ) def test_disposition_none(self) -> None: """ - If there is no filename, one isn't passed on in the Content-Disposition - of the request. + If there is no filename, Content-Disposition should only + be a disposition type. """ channel = self._req(None) @@ -373,7 +406,10 @@ class MediaRepoTests(unittest.HomeserverTestCase): self.assertEqual( headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type] ) - self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None) + self.assertEqual( + headers.getRawHeaders(b"Content-Disposition"), + [b"inline" if self.test_image.is_inline else b"attachment"], + ) def test_thumbnail_crop(self) -> None: """Test that a cropped remote thumbnail is available.""" @@ -447,11 +483,9 @@ class MediaRepoTests(unittest.HomeserverTestCase): # Fetching again should work, without re-requesting the image from the # remote. params = "?width=32&height=32&method=scale" - channel = make_request( - self.reactor, - FakeSite(self.thumbnail_resource, self.reactor), + channel = self.make_request( "GET", - self.media_id + params, + f"/_matrix/media/v3/thumbnail/{self.media_id}{params}", shorthand=False, await_result=False, ) @@ -477,11 +511,9 @@ class MediaRepoTests(unittest.HomeserverTestCase): ) shutil.rmtree(thumbnail_dir, ignore_errors=True) - channel = make_request( - self.reactor, - FakeSite(self.thumbnail_resource, self.reactor), + channel = self.make_request( "GET", - self.media_id + params, + f"/_matrix/media/v3/thumbnail/{self.media_id}{params}", shorthand=False, await_result=False, ) @@ -515,11 +547,9 @@ class MediaRepoTests(unittest.HomeserverTestCase): """ params = "?width=32&height=32&method=" + method - channel = make_request( - self.reactor, - FakeSite(self.thumbnail_resource, self.reactor), + channel = self.make_request( "GET", - self.media_id + params, + f"/_matrix/media/r0/thumbnail/{self.media_id}{params}", shorthand=False, await_result=False, ) @@ -556,7 +586,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel.json_body, { "errcode": "M_UNKNOWN", - "error": "Cannot find any thumbnails for the requested media ([b'example.com', b'12345']). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)", + "error": "Cannot find any thumbnails for the requested media ('/_matrix/media/r0/thumbnail/example.com/12345'). This might mean the media is not a supported_media_format=(image/jpeg, image/jpg, image/webp, image/gif, image/png) or that thumbnailing failed for some other reason. (Dynamic thumbnails are disabled on this server.)", }, ) else: @@ -566,7 +596,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel.json_body, { "errcode": "M_NOT_FOUND", - "error": "Not found [b'example.com', b'12345']", + "error": "Not found '/_matrix/media/r0/thumbnail/example.com/12345'", }, ) @@ -575,34 +605,39 @@ class MediaRepoTests(unittest.HomeserverTestCase): """Test that choosing between thumbnails with the same quality rating succeeds. We are not particular about which thumbnail is chosen.""" + + content_type = self.test_image.content_type.decode() + media_repo = self.hs.get_media_repository() + thumbnail_resouce = ThumbnailResource( + self.hs, media_repo, media_repo.media_storage + ) + self.assertIsNotNone( - self.thumbnail_resource._select_thumbnail( + thumbnail_resouce._select_thumbnail( desired_width=desired_size, desired_height=desired_size, desired_method=method, - desired_type=self.test_image.content_type, + desired_type=content_type, # Provide two identical thumbnails which are guaranteed to have the same # quality rating. thumbnail_infos=[ - { - "thumbnail_width": 32, - "thumbnail_height": 32, - "thumbnail_method": method, - "thumbnail_type": self.test_image.content_type, - "thumbnail_length": 256, - "filesystem_id": f"thumbnail1{self.test_image.extension.decode()}", - }, - { - "thumbnail_width": 32, - "thumbnail_height": 32, - "thumbnail_method": method, - "thumbnail_type": self.test_image.content_type, - "thumbnail_length": 256, - "filesystem_id": f"thumbnail2{self.test_image.extension.decode()}", - }, + ThumbnailInfo( + width=32, + height=32, + method=method, + type=content_type, + length=256, + ), + ThumbnailInfo( + width=32, + height=32, + method=method, + type=content_type, + length=256, + ), ], file_id=f"image{self.test_image.extension.decode()}", - url_cache=None, + url_cache=False, server_name=None, ) ) @@ -612,7 +647,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): Tests that the `X-Robots-Tag` header is present, which informs web crawlers to not index, archive, or follow links in media. """ - channel = self._req(b"inline; filename=out" + self.test_image.extension) + channel = self._req(b"attachment; filename=out" + self.test_image.extension) headers = channel.headers self.assertEqual( @@ -625,7 +660,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): Test that the Cross-Origin-Resource-Policy header is set to "cross-origin" allowing web clients to embed media from the downloads API. """ - channel = self._req(b"inline; filename=out" + self.test_image.extension) + channel = self._req(b"attachment; filename=out" + self.test_image.extension) headers = channel.headers @@ -691,13 +726,13 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase): self.user = self.register_user("user", "pass") self.tok = self.login("user", "pass") - # Allow for uploading and downloading to/from the media repo - self.media_repo = hs.get_media_repository_resource() - self.download_resource = self.media_repo.children[b"download"] - self.upload_resource = self.media_repo.children[b"upload"] - load_legacy_spam_checkers(hs) + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + def default_config(self) -> Dict[str, Any]: config = default_config("test") @@ -717,9 +752,7 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase): def test_upload_innocent(self) -> None: """Attempt to upload some innocent data that should be allowed.""" - self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200 - ) + self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200) def test_upload_ban(self) -> None: """Attempt to upload some data that includes bytes "evil", which should @@ -728,9 +761,7 @@ class SpamCheckerTestCaseLegacy(unittest.HomeserverTestCase): data = b"Some evil data" - self.helper.upload_media( - self.upload_resource, data, tok=self.tok, expect_code=400 - ) + self.helper.upload_media(data, tok=self.tok, expect_code=400) EVIL_DATA = b"Some evil data" @@ -747,15 +778,15 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): self.user = self.register_user("user", "pass") self.tok = self.login("user", "pass") - # Allow for uploading and downloading to/from the media repo - self.media_repo = hs.get_media_repository_resource() - self.download_resource = self.media_repo.children[b"download"] - self.upload_resource = self.media_repo.children[b"upload"] - hs.get_module_api().register_spam_checker_callbacks( check_media_file_for_spam=self.check_media_file_for_spam ) + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + async def check_media_file_for_spam( self, file_wrapper: ReadableFileWrapper, file_info: FileInfo ) -> Union[Codes, Literal["NOT_SPAM"], Tuple[Codes, JsonDict]]: @@ -771,21 +802,16 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase): def test_upload_innocent(self) -> None: """Attempt to upload some innocent data that should be allowed.""" - self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200 - ) + self.helper.upload_media(SMALL_PNG, tok=self.tok, expect_code=200) def test_upload_ban(self) -> None: """Attempt to upload some data that includes bytes "evil", which should get rejected by the spam checker. """ - self.helper.upload_media( - self.upload_resource, EVIL_DATA, tok=self.tok, expect_code=400 - ) + self.helper.upload_media(EVIL_DATA, tok=self.tok, expect_code=400) self.helper.upload_media( - self.upload_resource, EVIL_DATA_EXPERIMENT, tok=self.tok, expect_code=400, diff --git a/tests/media/test_oembed.py b/tests/media/test_oembed.py index c8bf8421da..3bc19cb1cc 100644 --- a/tests/media/test_oembed.py +++ b/tests/media/test_oembed.py @@ -28,7 +28,7 @@ from tests.unittest import HomeserverTestCase try: import lxml except ImportError: - lxml = None + lxml = None # type: ignore[assignment] class OEmbedTests(HomeserverTestCase): diff --git a/tests/media/test_url_previewer.py b/tests/media/test_url_previewer.py index 3c4c7d6765..04b69f378a 100644 --- a/tests/media/test_url_previewer.py +++ b/tests/media/test_url_previewer.py @@ -24,7 +24,7 @@ from tests.unittest import override_config try: import lxml except ImportError: - lxml = None + lxml = None # type: ignore[assignment] class URLPreviewTests(unittest.HomeserverTestCase): @@ -61,9 +61,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): return self.setup_test_homeserver(config=config) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - media_repo_resource = hs.get_media_repository_resource() - preview_url = media_repo_resource.children[b"preview_url"] - self.url_previewer = preview_url._url_previewer + media_repo = hs.get_media_repository() + assert media_repo.url_previewer is not None + self.url_previewer = media_repo.url_previewer def test_all_urls_allowed(self) -> None: self.assertFalse(self.url_previewer._is_url_blocked("http://matrix.org")) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 7c3656d049..d14876826c 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -12,19 +12,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from importlib import metadata from typing import Dict, Tuple - -from typing_extensions import Protocol - -try: - from importlib import metadata -except ImportError: - import importlib_metadata as metadata # type: ignore[no-redef] - from unittest.mock import patch from pkg_resources import parse_version from prometheus_client.core import Sample +from typing_extensions import Protocol from synapse.app._base import _set_prometheus_client_use_created_metrics from synapse.metrics import REGISTRY, InFlightGauge, generate_latest diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index bff7114cd8..172fc3a736 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, Optional -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor @@ -28,12 +28,11 @@ from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client import login, notifications, presence, profile, room from synapse.server import HomeServer -from synapse.types import JsonDict, create_requester +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests.events.test_presence_router import send_presence_update, sync_presence from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.test_utils import simple_async_mock from tests.test_utils.event_injection import inject_member_event from tests.unittest import HomeserverTestCase, override_config @@ -70,7 +69,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. self.fed_transport_client = Mock(spec=["send_transaction"]) - self.fed_transport_client.send_transaction = simple_async_mock({}) + self.fed_transport_client.send_transaction = AsyncMock(return_value={}) return self.setup_test_homeserver( federation_transport_client=self.fed_transport_client, @@ -103,7 +102,9 @@ class ModuleApiTestCase(BaseModuleApiTestCase): self.assertEqual(email["added_at"], 0) # Check that the displayname was assigned - displayname = self.get_success(self.store.get_profile_displayname("bob")) + displayname = self.get_success( + self.store.get_profile_displayname(UserID.from_string("@bob:test")) + ) self.assertEqual(displayname, "Bobberino") def test_can_register_admin_user(self) -> None: @@ -232,7 +233,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): def test_sending_events_into_room(self) -> None: """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent - self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[assignment] + self.event_creation_handler.create_and_send_nonmember_event = Mock( # type: ignore[method-assign] spec=[], side_effect=self.event_creation_handler.create_and_send_nonmember_event, ) @@ -577,10 +578,8 @@ class ModuleApiTestCase(BaseModuleApiTestCase): """Test that the module API can join a remote room.""" # Necessary to fake a remote join. fake_stream_id = 1 - mocked_remote_join = simple_async_mock( - return_value=("fake-event-id", fake_stream_id) - ) - self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[assignment] + mocked_remote_join = AsyncMock(return_value=("fake-event-id", fake_stream_id)) + self.hs.get_room_member_handler()._remote_join = mocked_remote_join # type: ignore[method-assign] fake_remote_host = f"{self.module_api.server_name}-remote" # Given that the join is to be faked, we expect the relevant join event not to @@ -755,7 +754,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase): self.assertEqual(channel.json_body["creator"], user_id) # Check room alias. - self.assertEquals(room_alias, f"#foo-bar:{self.module_api.server_name}") + self.assertEqual(room_alias, f"#foo-bar:{self.module_api.server_name}") # Let's try a room with no alias. room_id, room_alias = self.get_success( diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py index 9501096a77..7c23b77e0a 100644 --- a/tests/push/test_bulk_push_rule_evaluator.py +++ b/tests/push/test_bulk_push_rule_evaluator.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Any, Optional -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from parameterized import parameterized @@ -28,7 +28,6 @@ from synapse.server import HomeServer from synapse.types import JsonDict, create_requester from synapse.util import Clock -from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -191,7 +190,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): # Mock the method which calculates push rules -- we do this instead of # e.g. checking the results in the database because we want to ensure # that code isn't even running. - bulk_evaluator._action_for_event_by_user = simple_async_mock() # type: ignore[assignment] + bulk_evaluator._action_for_event_by_user = AsyncMock() # type: ignore[method-assign] # Ensure no actions are generated! self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)])) @@ -228,7 +227,6 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) return len(result) > 0 - @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) def test_user_mentions(self) -> None: """Test the behavior of an event which includes invalid user mentions.""" bulk_evaluator = BulkPushRuleEvaluator(self.hs) @@ -237,9 +235,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): self.assertFalse(self._create_and_process(bulk_evaluator)) # An empty mentions field should not notify. self.assertFalse( - self._create_and_process( - bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {}} - ) + self._create_and_process(bulk_evaluator, {EventContentFields.MENTIONS: {}}) ) # Non-dict mentions should be ignored. @@ -253,7 +249,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): for mentions in (None, True, False, 1, "foo", []): self.assertFalse( self._create_and_process( - bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: mentions} + bulk_evaluator, {EventContentFields.MENTIONS: mentions} ) ) @@ -262,7 +258,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): self.assertFalse( self._create_and_process( bulk_evaluator, - {EventContentFields.MSC3952_MENTIONS: {"user_ids": mentions}}, + {EventContentFields.MENTIONS: {"user_ids": mentions}}, ) ) @@ -270,14 +266,14 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): self.assertTrue( self._create_and_process( bulk_evaluator, - {EventContentFields.MSC3952_MENTIONS: {"user_ids": [self.alice]}}, + {EventContentFields.MENTIONS: {"user_ids": [self.alice]}}, ) ) self.assertTrue( self._create_and_process( bulk_evaluator, { - EventContentFields.MSC3952_MENTIONS: { + EventContentFields.MENTIONS: { "user_ids": ["@another:test", self.alice] } }, @@ -288,11 +284,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): self.assertTrue( self._create_and_process( bulk_evaluator, - { - EventContentFields.MSC3952_MENTIONS: { - "user_ids": [self.alice, self.alice] - } - }, + {EventContentFields.MENTIONS: {"user_ids": [self.alice, self.alice]}}, ) ) @@ -307,7 +299,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): self._create_and_process( bulk_evaluator, { - EventContentFields.MSC3952_MENTIONS: { + EventContentFields.MENTIONS: { "user_ids": [None, True, False, {}, []] } }, @@ -317,7 +309,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): self._create_and_process( bulk_evaluator, { - EventContentFields.MSC3952_MENTIONS: { + EventContentFields.MENTIONS: { "user_ids": [None, True, False, {}, [], self.alice] } }, @@ -331,12 +323,11 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): { "body": self.alice, "msgtype": "m.text", - EventContentFields.MSC3952_MENTIONS: {}, + EventContentFields.MENTIONS: {}, }, ) ) - @override_config({"experimental_features": {"msc3952_intentional_mentions": True}}) def test_room_mentions(self) -> None: """Test the behavior of an event which includes invalid room mentions.""" bulk_evaluator = BulkPushRuleEvaluator(self.hs) @@ -344,7 +335,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): # Room mentions from those without power should not notify. self.assertFalse( self._create_and_process( - bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {"room": True}} + bulk_evaluator, {EventContentFields.MENTIONS: {"room": True}} ) ) @@ -358,7 +349,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) self.assertTrue( self._create_and_process( - bulk_evaluator, {EventContentFields.MSC3952_MENTIONS: {"room": True}} + bulk_evaluator, {EventContentFields.MENTIONS: {"room": True}} ) ) @@ -374,7 +365,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): self.assertFalse( self._create_and_process( bulk_evaluator, - {EventContentFields.MSC3952_MENTIONS: {"room": mentions}}, + {EventContentFields.MENTIONS: {"room": mentions}}, ) ) @@ -385,12 +376,11 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): { "body": "@room", "msgtype": "m.text", - EventContentFields.MSC3952_MENTIONS: {}, + EventContentFields.MENTIONS: {}, }, ) ) - @override_config({"experimental_features": {"msc3958_supress_edit_notifs": True}}) def test_suppress_edits(self) -> None: """Under the default push rules, event edits should not generate notifications.""" bulk_evaluator = BulkPushRuleEvaluator(self.hs) @@ -417,16 +407,33 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase): ) ) - # Room mentions from those without power should not notify. + # The edit should not cause a notification. self.assertFalse( self._create_and_process( bulk_evaluator, { - "body": self.alice, + "body": "Test message", + "m.relates_to": { + "rel_type": RelationTypes.REPLACE, + "event_id": event.event_id, + }, + }, + ) + ) + + # An edit which is a mention will cause a notification. + self.assertTrue( + self._create_and_process( + bulk_evaluator, + { + "body": "Test message", "m.relates_to": { "rel_type": RelationTypes.REPLACE, "event_id": event.event_id, }, + "m.mentions": { + "user_ids": [self.alice], + }, }, ) ) diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 4b5c96aeae..73a430ddc6 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -13,10 +13,12 @@ # limitations under the License. import email.message import os +from http import HTTPStatus from typing import Any, Dict, List, Sequence, Tuple import attr import pkg_resources +from parameterized import parameterized from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor @@ -25,9 +27,11 @@ import synapse.rest.admin from synapse.api.errors import Codes, SynapseError from synapse.push.emailpusher import EmailPusher from synapse.rest.client import login, room +from synapse.rest.synapse.client.unsubscribe import UnsubscribeResource from synapse.server import HomeServer from synapse.util import Clock +from tests.server import FakeSite, make_request from tests.unittest import HomeserverTestCase @@ -175,6 +179,57 @@ class EmailPusherTests(HomeserverTestCase): self._check_for_mail() + @parameterized.expand([(False,), (True,)]) + def test_unsubscribe(self, use_post: bool) -> None: + # Create a simple room with two users + room = self.helper.create_room_as(self.user_id, tok=self.access_token) + self.helper.invite( + room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id + ) + self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token) + + # The other user sends a single message. + self.helper.send(room, body="Hi!", tok=self.others[0].token) + + # We should get emailed about that message + args, kwargs = self._check_for_mail() + + # That email should contain an unsubscribe link in the body and header. + msg: bytes = args[5] + + # Multipart: plain text, base 64 encoded; html, base 64 encoded + multipart_msg = email.message_from_bytes(msg) + txt = multipart_msg.get_payload()[0].get_payload(decode=True).decode() + html = multipart_msg.get_payload()[1].get_payload(decode=True).decode() + self.assertIn("/_synapse/client/unsubscribe", txt) + self.assertIn("/_synapse/client/unsubscribe", html) + + # The unsubscribe headers should exist. + assert multipart_msg.get("List-Unsubscribe") is not None + self.assertIsNotNone(multipart_msg.get("List-Unsubscribe-Post")) + + # Open the unsubscribe link. + unsubscribe_link = multipart_msg["List-Unsubscribe"].strip("<>") + unsubscribe_resource = UnsubscribeResource(self.hs) + channel = make_request( + self.reactor, + FakeSite(unsubscribe_resource, self.reactor), + "POST" if use_post else "GET", + unsubscribe_link, + shorthand=False, + ) + self.assertEqual(HTTPStatus.OK, channel.code, channel.result) + + # Ensure the pusher was removed. + pushers = list( + self.get_success( + self.hs.get_datastores().main.get_pushers_by( + {"user_name": self.user_id} + ) + ) + ) + self.assertEqual(pushers, []) + def test_invite_sends_email(self) -> None: # Create a room and invite the user to it room = self.helper.create_room_as(self.others[0].id, tok=self.others[0].token) diff --git a/tests/replication/_base.py b/tests/replication/_base.py index eb9b1f1cd9..6712ac485d 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource from synapse.app.generic_worker import GenericWorkerServer +from synapse.config.workers import InstanceTcpLocationConfig, InstanceUnixLocationConfig from synapse.http.site import SynapseRequest, SynapseSite from synapse.replication.http import ReplicationRestResource from synapse.replication.tcp.client import ReplicationDataHandler @@ -69,10 +70,10 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Make a new HomeServer object for the worker self.reactor.lookups["testserv"] = "1.2.3.4" self.worker_hs = self.setup_test_homeserver( - federation_http_client=None, homeserver_to_use=GenericWorkerServer, config=self._get_worker_hs_config(), reactor=self.reactor, + federation_http_client=None, ) # Since we use sqlite in memory databases we need to make sure the @@ -339,7 +340,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # `_handle_http_replication_attempt` like we do with the master HS. instance_name = worker_hs.get_instance_name() instance_loc = worker_hs.config.worker.instance_map.get(instance_name) - if instance_loc: + if instance_loc and isinstance(instance_loc, InstanceTcpLocationConfig): # Ensure the host is one that has a fake DNS entry. if instance_loc.host not in self.reactor.lookups: raise Exception( @@ -360,6 +361,10 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): instance_loc.port, lambda: self._handle_http_replication_attempt(worker_hs, port), ) + elif instance_loc and isinstance(instance_loc, InstanceUnixLocationConfig): + raise Exception( + "Unix sockets are not supported for unit tests at this time." + ) store = worker_hs.get_datastores().main store.db_pool._db_pool = self.database_pool._db_pool @@ -380,6 +385,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): server_version_string="1", max_request_body_size=8192, reactor=self.reactor, + hs=worker_hs, ) worker_hs.get_replication_command_handler().start_replication(worker_hs) diff --git a/tests/replication/storage/_base.py b/tests/replication/storage/_base.py index de26a62ae1..afcc80a8b3 100644 --- a/tests/replication/storage/_base.py +++ b/tests/replication/storage/_base.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Iterable, Optional +from typing import Any, Callable, Iterable, Optional from unittest.mock import Mock from twisted.test.proto_helpers import MemoryReactor @@ -47,24 +47,31 @@ class BaseWorkerStoreTestCase(BaseStreamTestCase): self.pump(0.1) def check( - self, method: str, args: Iterable[Any], expected_result: Optional[Any] = None + self, + method: str, + args: Iterable[Any], + expected_result: Optional[Any] = None, + asserter: Optional[Callable[[Any, Any, Optional[Any]], None]] = None, ) -> None: + if asserter is None: + asserter = self.assertEqual + master_result = self.get_success(getattr(self.master_store, method)(*args)) worker_result = self.get_success(getattr(self.worker_store, method)(*args)) if expected_result is not None: - self.assertEqual( + asserter( master_result, expected_result, "Expected master result to be %r but was %r" % (expected_result, master_result), ) - self.assertEqual( + asserter( worker_result, expected_result, "Expected worker result to be %r but was %r" % (expected_result, worker_result), ) - self.assertEqual( + asserter( master_result, worker_result, "Worker result %r does not match master result %r" diff --git a/tests/replication/storage/test_events.py b/tests/replication/storage/test_events.py index f7c6417a09..17716253f8 100644 --- a/tests/replication/storage/test_events.py +++ b/tests/replication/storage/test_events.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, Iterable, List, Optional, Tuple +from typing import Any, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json from parameterized import parameterized @@ -21,7 +21,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import ReceiptTypes from synapse.api.room_versions import RoomVersions -from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.handlers.room import RoomEventSource from synapse.server import HomeServer @@ -46,32 +46,9 @@ ROOM_ID = "!room:test" logger = logging.getLogger(__name__) -def dict_equals(self: EventBase, other: EventBase) -> bool: - me = encode_canonical_json(self.get_pdu_json()) - them = encode_canonical_json(other.get_pdu_json()) - return me == them - - -def patch__eq__(cls: object) -> Callable[[], None]: - eq = getattr(cls, "__eq__", None) - cls.__eq__ = dict_equals # type: ignore[assignment] - - def unpatch() -> None: - if eq is not None: - cls.__eq__ = eq # type: ignore[assignment] - - return unpatch - - class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): STORE_TYPE = EventsWorkerStore - def setUp(self) -> None: - # Patch up the equality operator for events so that we can check - # whether lists of events match using assertEqual - self.unpatches = [patch__eq__(_EventInternalMetadata), patch__eq__(EventBase)] - super().setUp() - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: super().prepare(reactor, clock, hs) @@ -84,13 +61,19 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): ) ) - def tearDown(self) -> None: - [unpatch() for unpatch in self.unpatches] + def assertEventsEqual( + self, first: EventBase, second: EventBase, msg: Optional[Any] = None + ) -> None: + self.assertEqual( + encode_canonical_json(first.get_pdu_json()), + encode_canonical_json(second.get_pdu_json()), + msg, + ) def test_get_latest_event_ids_in_room(self) -> None: create = self.persist(type="m.room.create", key="", creator=USER_ID) self.replicate() - self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id]) + self.check("get_latest_event_ids_in_room", (ROOM_ID,), {create.event_id}) join = self.persist( type="m.room.member", @@ -99,7 +82,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): prev_events=[(create.event_id, {})], ) self.replicate() - self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]) + self.check("get_latest_event_ids_in_room", (ROOM_ID,), {join.event_id}) def test_redactions(self) -> None: self.persist(type="m.room.create", key="", creator=USER_ID) @@ -107,7 +90,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") self.replicate() - self.check("get_event", [msg.event_id], msg) + self.check("get_event", [msg.event_id], msg, asserter=self.assertEventsEqual) redaction = self.persist(type="m.room.redaction", redacts=msg.event_id) self.replicate() @@ -119,7 +102,9 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): redacted = make_event_from_dict( msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() ) - self.check("get_event", [msg.event_id], redacted) + self.check( + "get_event", [msg.event_id], redacted, asserter=self.assertEventsEqual + ) def test_backfilled_redactions(self) -> None: self.persist(type="m.room.create", key="", creator=USER_ID) @@ -127,7 +112,7 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello") self.replicate() - self.check("get_event", [msg.event_id], msg) + self.check("get_event", [msg.event_id], msg, asserter=self.assertEventsEqual) redaction = self.persist( type="m.room.redaction", redacts=msg.event_id, backfill=True @@ -141,7 +126,9 @@ class EventsWorkerStoreTestCase(BaseWorkerStoreTestCase): redacted = make_event_from_dict( msg_dict, internal_metadata_dict=msg.internal_metadata.get_dict() ) - self.check("get_event", [msg.event_id], redacted) + self.check( + "get_event", [msg.event_id], redacted, asserter=self.assertEventsEqual + ) def test_invites(self) -> None: self.persist(type="m.room.create", key="", creator=USER_ID) diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 65ef4bb160..128fc3e046 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Optional, Sequence +from typing import Any, List, Optional from twisted.test.proto_helpers import MemoryReactor @@ -139,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point: Sequence[str] = self.get_success( + fork_point = self.get_success( self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) @@ -294,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point: Sequence[str] = self.get_success( + fork_point = self.get_success( self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) @@ -316,14 +316,14 @@ class EventsStreamTestCase(BaseStreamTestCase): self.test_handler.received_rdata_rows.clear() # now roll back all that state by de-modding the users - prev_events = fork_point + prev_events = list(fork_point) pl_events = [] for u in user_ids: pls["users"][u] = 0 e = self.get_success( inject_event( self.hs, - prev_event_ids=list(prev_events), + prev_event_ids=prev_events, type=EventTypes.PowerLevels, state_key="", sender=self.user_id, diff --git a/tests/replication/tcp/streams/test_to_device.py b/tests/replication/tcp/streams/test_to_device.py index fb9eac668f..ab379e8cf1 100644 --- a/tests/replication/tcp/streams/test_to_device.py +++ b/tests/replication/tcp/streams/test_to_device.py @@ -49,7 +49,7 @@ class ToDeviceStreamTestCase(BaseStreamTestCase): # add messages to the device inbox for user1 up until the # limit defined for a stream update batch - for i in range(0, _STREAM_UPDATE_TARGET_ROW_COUNT): + for i in range(_STREAM_UPDATE_TARGET_ROW_COUNT): msg["content"] = {"device": {}} messages = {user1: {"device": msg}} diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py index 12668b34c5..cf59b1a204 100644 --- a/tests/replication/test_federation_ack.py +++ b/tests/replication/test_federation_ack.py @@ -32,6 +32,7 @@ class FederationAckTestCase(HomeserverTestCase): config["worker_app"] = "synapse.app.generic_worker" config["worker_name"] = "federation_sender1" config["federation_sender_instances"] = ["federation_sender1"] + config["instance_map"] = {"main": {"host": "127.0.0.1", "port": 0}} return config def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 08703206a9..59f4fdc70b 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -12,17 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock + +from netaddr import IPSet from synapse.api.constants import EventTypes, Membership from synapse.events.builder import EventBuilderFactory from synapse.handlers.typing import TypingWriterHandler +from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.client import login, room from synapse.types import UserID, create_requester from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.test_utils import make_awaitable +from tests.server import get_clock logger = logging.getLogger(__name__) @@ -41,13 +44,25 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): room.register_servlets, ] + def setUp(self) -> None: + super().setUp() + + reactor, _ = get_clock() + self.matrix_federation_agent = MatrixFederationAgent( + reactor, + tls_client_options_factory=None, + user_agent=b"SynapseInTrialTest/0.0.0", + ip_allowlist=None, + ip_blocklist=IPSet(), + ) + def test_send_event_single_sender(self) -> None: """Test that using a single federation sender worker correctly sends a new event. """ mock_client = Mock(spec=["put_json"]) - mock_client.put_json.return_value = make_awaitable({}) - + mock_client.put_json = AsyncMock(return_value={}) + mock_client.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", { @@ -77,7 +92,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new events. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.return_value = make_awaitable({}) + mock_client1.put_json = AsyncMock(return_value={}) + mock_client1.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", { @@ -91,7 +107,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.return_value = make_awaitable({}) + mock_client2.put_json = AsyncMock(return_value={}) + mock_client2.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", { @@ -144,7 +161,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new typing EDUs. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.return_value = make_awaitable({}) + mock_client1.put_json = AsyncMock(return_value={}) + mock_client1.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", { @@ -158,7 +176,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.return_value = make_awaitable({}) + mock_client2.put_json = AsyncMock(return_value={}) + mock_client2.agent = self.matrix_federation_agent self.make_worker_hs( "synapse.app.generic_worker", { @@ -242,7 +261,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): builder = factory.for_room_version(room_version, event_dict) join_event = self.get_success( - builder.build(prev_event_ids=prev_event_ids, auth_event_ids=None) + builder.build(prev_event_ids=list(prev_event_ids), auth_event_ids=None) ) self.get_success(federation.on_send_membership_event(remote_server, join_event)) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 1527b4a82d..b230a6c361 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -13,7 +13,7 @@ # limitations under the License. import logging import os -from typing import Optional, Tuple +from typing import Any, Optional, Tuple from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.protocol import Factory @@ -29,7 +29,7 @@ from synapse.util import Clock from tests.http import TestServerTLSConnectionFactory, get_test_ca_cert_file from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.server import FakeChannel, FakeSite, FakeTransport, make_request +from tests.server import FakeChannel, FakeTransport, make_request from tests.test_utils import SMALL_PNG logger = logging.getLogger(__name__) @@ -56,6 +56,16 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): conf["federation_custom_ca_list"] = [get_test_ca_cert_file()] return conf + def make_worker_hs( + self, worker_app: str, extra_config: Optional[dict] = None, **kwargs: Any + ) -> HomeServer: + worker_hs = super().make_worker_hs(worker_app, extra_config, **kwargs) + # Force the media paths onto the replication resource. + worker_hs.get_media_repository_resource().register_servlets( + self._hs_to_site[worker_hs].resource, worker_hs + ) + return worker_hs + def _get_media_req( self, hs: HomeServer, target: str, media_id: str ) -> Tuple[FakeChannel, Request]: @@ -68,12 +78,11 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): The channel for the *client* request and the *outbound* request for the media which the caller should respond to. """ - resource = hs.get_media_repository_resource().children[b"download"] channel = make_request( self.reactor, - FakeSite(resource, self.reactor), + self._hs_to_site[hs], "GET", - f"/{target}/{media_id}", + f"/_matrix/media/r0/download/{target}/{media_id}", shorthand=False, access_token=self.access_token, await_result=False, @@ -116,7 +125,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.assertEqual(request.method, b"GET") self.assertEqual( request.path, - f"/_matrix/media/r0/download/{target}/{media_id}".encode("utf-8"), + f"/_matrix/media/r0/download/{target}/{media_id}".encode(), ) self.assertEqual( request.requestHeaders.getRawHeaders(b"host"), [target.encode("utf-8")] diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index 695e84357a..8646b2f0fd 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -13,10 +13,12 @@ # limitations under the License. import urllib.parse +from typing import Dict from parameterized import parameterized from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource import synapse.rest.admin from synapse.http.server import JsonResource @@ -26,7 +28,6 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.server import FakeSite, make_request from tests.test_utils import SMALL_PNG @@ -42,9 +43,7 @@ class VersionTestCase(unittest.HomeserverTestCase): channel = self.make_request("GET", self.url, shorthand=False) self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual( - {"server_version", "python_version"}, set(channel.json_body.keys()) - ) + self.assertEqual({"server_version"}, set(channel.json_body.keys())) class QuarantineMediaTestCase(unittest.HomeserverTestCase): @@ -57,21 +56,18 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): room.register_servlets, ] - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - # Allow for uploading and downloading to/from the media repo - self.media_repo = hs.get_media_repository_resource() - self.download_resource = self.media_repo.children[b"download"] - self.upload_resource = self.media_repo.children[b"upload"] + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources def _ensure_quarantined( self, admin_user_tok: str, server_and_media_id: str ) -> None: """Ensure a piece of media is quarantined when trying to access it.""" - channel = make_request( - self.reactor, - FakeSite(self.download_resource, self.reactor), + channel = self.make_request( "GET", - server_and_media_id, + f"/_matrix/media/v3/download/{server_and_media_id}", shorthand=False, access_token=admin_user_tok, ) @@ -119,20 +115,16 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): non_admin_user_tok = self.login("id_nonadmin", "pass") # Upload some media into the room - response = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=admin_user_tok - ) + response = self.helper.upload_media(SMALL_PNG, tok=admin_user_tok) # Extract media ID from the response server_name_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' server_name, media_id = server_name_and_media_id.split("/") # Attempt to access the media - channel = make_request( - self.reactor, - FakeSite(self.download_resource, self.reactor), + channel = self.make_request( "GET", - server_name_and_media_id, + f"/_matrix/media/v3/download/{server_name_and_media_id}", shorthand=False, access_token=non_admin_user_tok, ) @@ -175,12 +167,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self.helper.join(room_id, non_admin_user, tok=non_admin_user_tok) # Upload some media - response_1 = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=non_admin_user_tok - ) - response_2 = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=non_admin_user_tok - ) + response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok) + response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok) # Extract mxcs mxc_1 = response_1["content_uri"] @@ -229,12 +217,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): non_admin_user_tok = self.login("user_nonadmin", "pass") # Upload some media - response_1 = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=non_admin_user_tok - ) - response_2 = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=non_admin_user_tok - ) + response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok) + response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok) # Extract media IDs server_and_media_id_1 = response_1["content_uri"][6:] @@ -267,12 +251,8 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): non_admin_user_tok = self.login("user_nonadmin", "pass") # Upload some media - response_1 = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=non_admin_user_tok - ) - response_2 = self.helper.upload_media( - self.upload_resource, SMALL_PNG, tok=non_admin_user_tok - ) + response_1 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok) + response_2 = self.helper.upload_media(SMALL_PNG, tok=non_admin_user_tok) # Extract media IDs server_and_media_id_1 = response_1["content_uri"][6:] @@ -306,11 +286,9 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self._ensure_quarantined(admin_user_tok, server_and_media_id_1) # Attempt to access each piece of media - channel = make_request( - self.reactor, - FakeSite(self.download_resource, self.reactor), + channel = self.make_request( "GET", - server_and_media_id_2, + f"/_matrix/media/v3/download/{server_and_media_id_2}", shorthand=False, access_token=non_admin_user_tok, ) diff --git a/tests/rest/admin/test_federation.py b/tests/rest/admin/test_federation.py index 4c7864c629..0e2824d1b5 100644 --- a/tests/rest/admin/test_federation.py +++ b/tests/rest/admin/test_federation.py @@ -510,7 +510,7 @@ class FederationTestCase(unittest.HomeserverTestCase): Args: number_destinations: Number of destinations to be created """ - for i in range(0, number_destinations): + for i in range(number_destinations): dest = f"sub{i}.example.com" self._create_destination(dest, 50, 50, 50, 100) @@ -690,7 +690,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): self._check_fields(channel_desc.json_body["rooms"]) # test that both lists have different directions - for i in range(0, number_rooms): + for i in range(number_rooms): self.assertEqual( channel_asc.json_body["rooms"][i]["room_id"], channel_desc.json_body["rooms"][number_rooms - 1 - i]["room_id"], @@ -777,7 +777,7 @@ class DestinationMembershipTestCase(unittest.HomeserverTestCase): Args: number_rooms: Number of rooms to be created """ - for _ in range(0, number_rooms): + for _ in range(number_rooms): room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok ) diff --git a/tests/rest/admin/test_jwks.py b/tests/rest/admin/test_jwks.py new file mode 100644 index 0000000000..a9a6191c73 --- /dev/null +++ b/tests/rest/admin/test_jwks.py @@ -0,0 +1,106 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +from twisted.web.resource import Resource + +from synapse.rest.synapse.client import build_synapse_client_resource_tree + +from tests.unittest import HomeserverTestCase, override_config, skip_unless + +try: + import authlib # noqa: F401 + + HAS_AUTHLIB = True +except ImportError: + HAS_AUTHLIB = False + + +@skip_unless(HAS_AUTHLIB, "requires authlib") +class JWKSTestCase(HomeserverTestCase): + """Test /_synapse/jwks JWKS data.""" + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d.update(build_synapse_client_resource_tree(self.hs)) + return d + + def test_empty_jwks(self) -> None: + """Test that the JWKS endpoint is not present by default.""" + channel = self.make_request("GET", "/_synapse/jwks") + self.assertEqual(404, channel.code, channel.result) + + @override_config( + { + "disable_registration": True, + "experimental_features": { + "msc3861": { + "enabled": True, + "issuer": "https://issuer/", + "client_id": "test-client-id", + "client_auth_method": "client_secret_post", + "client_secret": "secret", + }, + }, + } + ) + def test_empty_jwks_for_msc3861_client_secret_post(self) -> None: + """Test that the JWKS endpoint is empty when plain auth is used.""" + channel = self.make_request("GET", "/_synapse/jwks") + self.assertEqual(200, channel.code, channel.result) + self.assertEqual({"keys": []}, channel.json_body) + + @override_config( + { + "disable_registration": True, + "experimental_features": { + "msc3861": { + "enabled": True, + "issuer": "https://issuer/", + "client_id": "test-client-id", + "client_auth_method": "private_key_jwt", + "jwk": { + "p": "-frVdP_tZ-J_nIR6HNMDq1N7aunwm51nAqNnhqIyuA8ikx7LlQED1tt2LD3YEvYyW8nxE2V95HlCRZXQPMiRJBFOsbmYkzl2t-MpavTaObB_fct_JqcRtdXddg4-_ihdjRDwUOreq_dpWh6MIKsC3UyekfkHmeEJg5YpOTL15j8", + "kty": "RSA", + "q": "oFw-Enr_YozQB1ab-kawn4jY3yHi8B1nSmYT0s8oTCflrmps5BFJfCkHL5ij3iY15z0o2m0N-jjB1oSJ98O4RayEEYNQlHnTNTl0kRIWzpoqblHUIxVcahIpP_xTovBJzwi8XXoLGqHOOMA-r40LSyVgP2Ut8D9qBwV6_UfT0LU", + "d": "WFkDPYo4b4LIS64D_QtQfGGuAObPvc3HFfp9VZXyq3SJR58XZRHE0jqtlEMNHhOTgbMYS3w8nxPQ_qVzY-5hs4fIanwvB64mAoOGl0qMHO65DTD_WsGFwzYClJPBVniavkLE2Hmpu8IGe6lGliN8vREC6_4t69liY-XcN_ECboVtC2behKkLOEASOIMuS7YcKAhTJFJwkl1dqDlliEn5A4u4xy7nuWQz3juB1OFdKlwGA5dfhDNglhoLIwNnkLsUPPFO-WB5ZNEW35xxHOToxj4bShvDuanVA6mJPtTKjz0XibjB36bj_nF_j7EtbE2PdGJ2KevAVgElR4lqS4ISgQ", + "e": "AQAB", + "kid": "test", + "qi": "cPfNk8l8W5exVNNea4d7QZZ8Qr8LgHghypYAxz8PQh1fNa8Ya1SNUDVzC2iHHhszxxA0vB9C7jGze8dBrvnzWYF1XvQcqNIVVgHhD57R1Nm3dj2NoHIKe0Cu4bCUtP8xnZQUN4KX7y4IIcgRcBWG1hT6DEYZ4BxqicnBXXNXAUI", + "dp": "dKlMHvslV1sMBQaKWpNb3gPq0B13TZhqr3-E2_8sPlvJ3fD8P4CmwwnOn50JDuhY3h9jY5L06sBwXjspYISVv8hX-ndMLkEeF3lrJeA5S70D8rgakfZcPIkffm3tlf1Ok3v5OzoxSv3-67Df4osMniyYwDUBCB5Oq1tTx77xpU8", + "dq": "S4ooU1xNYYcjl9FcuJEEMqKsRrAXzzSKq6laPTwIp5dDwt2vXeAm1a4eDHXC-6rUSZGt5PbqVqzV4s-cjnJMI8YYkIdjNg4NSE1Ac_YpeDl3M3Colb5CQlU7yUB7xY2bt0NOOFp9UJZYJrOo09mFMGjy5eorsbitoZEbVqS3SuE", + "n": "nJbYKqFwnURKimaviyDFrNLD3gaKR1JW343Qem25VeZxoMq1665RHVoO8n1oBm4ClZdjIiZiVdpyqzD5-Ow12YQgQEf1ZHP3CCcOQQhU57Rh5XvScTe5IxYVkEW32IW2mp_CJ6WfjYpfeL4azarVk8H3Vr59d1rSrKTVVinVdZer9YLQyC_rWAQNtHafPBMrf6RYiNGV9EiYn72wFIXlLlBYQ9Fx7bfe1PaL6qrQSsZP3_rSpuvVdLh1lqGeCLR0pyclA9uo5m2tMyCXuuGQLbA_QJm5xEc7zd-WFdux2eXF045oxnSZ_kgQt-pdN7AxGWOVvwoTf9am6mSkEdv6iw", + }, + }, + }, + } + ) + def test_key_returned_for_msc3861_client_secret_post(self) -> None: + """Test that the JWKS includes public part of JWK for private_key_jwt auth is used.""" + channel = self.make_request("GET", "/_synapse/jwks") + self.assertEqual(200, channel.code, channel.result) + self.assertEqual( + { + "keys": [ + { + "kty": "RSA", + "e": "AQAB", + "kid": "test", + "n": "nJbYKqFwnURKimaviyDFrNLD3gaKR1JW343Qem25VeZxoMq1665RHVoO8n1oBm4ClZdjIiZiVdpyqzD5-Ow12YQgQEf1ZHP3CCcOQQhU57Rh5XvScTe5IxYVkEW32IW2mp_CJ6WfjYpfeL4azarVk8H3Vr59d1rSrKTVVinVdZer9YLQyC_rWAQNtHafPBMrf6RYiNGV9EiYn72wFIXlLlBYQ9Fx7bfe1PaL6qrQSsZP3_rSpuvVdLh1lqGeCLR0pyclA9uo5m2tMyCXuuGQLbA_QJm5xEc7zd-WFdux2eXF045oxnSZ_kgQt-pdN7AxGWOVvwoTf9am6mSkEdv6iw", + } + ] + }, + channel.json_body, + ) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 6d04911d67..278808abb5 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -13,10 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from typing import Dict from parameterized import parameterized from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.errors import Codes @@ -26,22 +28,27 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.server import FakeSite, make_request from tests.test_utils import SMALL_PNG VALID_TIMESTAMP = 1609459200000 # 2021-01-01 in milliseconds INVALID_TIMESTAMP_IN_S = 1893456000 # 2030-01-01 in seconds -class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): +class _AdminMediaTests(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, login.register_servlets, ] + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + + +class DeleteMediaByIDTestCase(_AdminMediaTests): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.media_repo = hs.get_media_repository_resource() self.server_name = hs.hostname self.admin_user = self.register_user("admin", "pass", admin=True) @@ -117,12 +124,8 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): Tests that delete a media is successfully """ - download_resource = self.media_repo.children[b"download"] - upload_resource = self.media_repo.children[b"upload"] - # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200, @@ -134,11 +137,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.assertEqual(server_name, self.server_name) # Attempt to access media - channel = make_request( - self.reactor, - FakeSite(download_resource, self.reactor), + channel = self.make_request( "GET", - server_and_media_id, + f"/_matrix/media/v3/download/{server_and_media_id}", shorthand=False, access_token=self.admin_user_tok, ) @@ -173,11 +174,9 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): ) # Attempt to access media - channel = make_request( - self.reactor, - FakeSite(download_resource, self.reactor), + channel = self.make_request( "GET", - server_and_media_id, + f"/_matrix/media/v3/download/{server_and_media_id}", shorthand=False, access_token=self.admin_user_tok, ) @@ -194,7 +193,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(os.path.exists(local_path)) -class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): +class DeleteMediaByDateSizeTestCase(_AdminMediaTests): servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, @@ -529,11 +528,8 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): """ Create a media and return media_id and server_and_media_id """ - upload_resource = self.media_repo.children[b"upload"] - # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200, @@ -553,16 +549,12 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): """ Try to access a media and check the result """ - download_resource = self.media_repo.children[b"download"] - media_id = server_and_media_id.split("/")[1] local_path = self.filepaths.local_media_filepath(media_id) - channel = make_request( - self.reactor, - FakeSite(download_resource, self.reactor), + channel = self.make_request( "GET", - server_and_media_id, + f"/_matrix/media/v3/download/{server_and_media_id}", shorthand=False, access_token=self.admin_user_tok, ) @@ -591,27 +583,16 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertFalse(os.path.exists(local_path)) -class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - synapse.rest.admin.register_servlets_for_media_repo, - login.register_servlets, - ] - +class QuarantineMediaByIDTestCase(_AdminMediaTests): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - media_repo = hs.get_media_repository_resource() self.store = hs.get_datastores().main self.server_name = hs.hostname self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - # Create media - upload_resource = media_repo.children[b"upload"] - # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200, @@ -720,26 +701,16 @@ class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(media_info["quarantined_by"]) -class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - synapse.rest.admin.register_servlets_for_media_repo, - login.register_servlets, - ] - +class ProtectMediaByIDTestCase(_AdminMediaTests): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - media_repo = hs.get_media_repository_resource() + hs.get_media_repository_resource() self.store = hs.get_datastores().main self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - # Create media - upload_resource = media_repo.children[b"upload"] - # Upload some media into the room response = self.helper.upload_media( - upload_resource, SMALL_PNG, tok=self.admin_user_tok, expect_code=200, @@ -816,7 +787,7 @@ class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): self.assertFalse(media_info["safe_from_quarantine"]) -class PurgeMediaCacheTestCase(unittest.HomeserverTestCase): +class PurgeMediaCacheTestCase(_AdminMediaTests): servlets = [ synapse.rest.admin.register_servlets, synapse.rest.admin.register_servlets_for_media_repo, diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index eb50086c50..6ed451d7c4 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -15,26 +15,34 @@ import json import time import urllib.parse from typing import List, Optional -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from parameterized import parameterized +from twisted.internet.task import deferLater from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import EventTypes, Membership, RoomTypes from synapse.api.errors import Codes -from synapse.handlers.pagination import PaginationHandler, PurgeStatus +from synapse.handlers.pagination import ( + PURGE_ROOM_ACTION_NAME, + SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME, +) from synapse.rest.client import directory, events, login, room from synapse.server import HomeServer +from synapse.types import UserID from synapse.util import Clock -from synapse.util.stringutils import random_string +from synapse.util.task_scheduler import TaskScheduler from tests import unittest """Tests admin REST events for /rooms paths.""" +ONE_HOUR_IN_S = 3600 + + class DeleteRoomTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, @@ -46,6 +54,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.event_creation_handler = hs.get_event_creation_handler() + self.task_scheduler = hs.get_task_scheduler() hs.config.consent.user_consent_version = "1" consent_uri_builder = Mock() @@ -476,6 +485,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.event_creation_handler = hs.get_event_creation_handler() + self.task_scheduler = hs.get_task_scheduler() hs.config.consent.user_consent_version = "1" consent_uri_builder = Mock() @@ -502,6 +512,9 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): ) self.url_status_by_delete_id = "/_synapse/admin/v2/rooms/delete_status/" + self.room_member_handler = hs.get_room_member_handler() + self.pagination_handler = hs.get_pagination_handler() + @parameterized.expand( [ ("DELETE", "/_synapse/admin/v2/rooms/%s"), @@ -661,7 +674,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): delete_id1 = channel.json_body["delete_id"] # go ahead - self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + self.reactor.advance(TaskScheduler.KEEP_TASKS_FOR_MS / 1000 / 2) # second task channel = self.make_request( @@ -686,12 +699,14 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.assertEqual(2, len(channel.json_body["results"])) self.assertEqual("complete", channel.json_body["results"][0]["status"]) self.assertEqual("complete", channel.json_body["results"][1]["status"]) - self.assertEqual(delete_id1, channel.json_body["results"][0]["delete_id"]) - self.assertEqual(delete_id2, channel.json_body["results"][1]["delete_id"]) + delete_ids = {delete_id1, delete_id2} + self.assertTrue(channel.json_body["results"][0]["delete_id"] in delete_ids) + delete_ids.remove(channel.json_body["results"][0]["delete_id"]) + self.assertTrue(channel.json_body["results"][1]["delete_id"] in delete_ids) # get status after more than clearing time for first task # second task is not cleared - self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + self.reactor.advance(TaskScheduler.KEEP_TASKS_FOR_MS / 1000 / 2) channel = self.make_request( "GET", @@ -705,7 +720,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.assertEqual(delete_id2, channel.json_body["results"][0]["delete_id"]) # get status after more than clearing time for all tasks - self.reactor.advance(PaginationHandler.CLEAR_PURGE_AFTER_MS / 1000 / 2) + self.reactor.advance(TaskScheduler.KEEP_TASKS_FOR_MS / 1000 / 2) channel = self.make_request( "GET", @@ -721,6 +736,13 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): body = {"new_room_user_id": self.admin_user} + # Mock PaginationHandler.purge_room to sleep for 100s, so we have time to do a second call + # before the purge is over. Note that it doesn't purge anymore, but we don't care. + async def purge_room(room_id: str, force: bool) -> None: + await deferLater(self.hs.get_reactor(), 100, lambda: None) + + self.pagination_handler.purge_room = AsyncMock(side_effect=purge_room) # type: ignore[method-assign] + # first call to delete room # and do not wait for finish the task first_channel = self.make_request( @@ -728,7 +750,6 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.url.encode("ascii"), content=body, access_token=self.admin_user_tok, - await_result=False, ) # second call to delete room @@ -742,7 +763,7 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.assertEqual(400, second_channel.code, msg=second_channel.json_body) self.assertEqual(Codes.UNKNOWN, second_channel.json_body["errcode"]) self.assertEqual( - f"History purge already in progress for {self.room_id}", + f"Purge already in progress for {self.room_id}", second_channel.json_body["error"], ) @@ -751,6 +772,9 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): self.assertEqual(200, first_channel.code, msg=first_channel.json_body) self.assertIn("delete_id", first_channel.json_body) + # wait for purge_room to finish + self.pump(1) + # check status after finish the task self._test_result( first_channel.json_body["delete_id"], @@ -972,6 +996,115 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): # Assert we can no longer peek into the room self._assert_peek(self.room_id, expect_code=403) + @unittest.override_config({"forgotten_room_retention_period": "1d"}) + def test_purge_forgotten_room(self) -> None: + # Create a test room + room_id = self.helper.create_room_as( + self.admin_user, + tok=self.admin_user_tok, + ) + + self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) + self.get_success( + self.room_member_handler.forget( + UserID.from_string(self.admin_user), room_id + ) + ) + + # Test that room is not yet purged + with self.assertRaises(AssertionError): + self._is_purged(room_id) + + # Advance 24 hours in the future, past the `forgotten_room_retention_period` + self.reactor.advance(24 * ONE_HOUR_IN_S) + + self._is_purged(room_id) + + def test_scheduled_purge_room(self) -> None: + # Create a test room + room_id = self.helper.create_room_as( + self.admin_user, + tok=self.admin_user_tok, + ) + self.helper.leave(room_id, user=self.admin_user, tok=self.admin_user_tok) + + # Schedule a purge 10 seconds in the future + self.get_success( + self.task_scheduler.schedule_task( + PURGE_ROOM_ACTION_NAME, + resource_id=room_id, + timestamp=self.clock.time_msec() + 10 * 1000, + ) + ) + + # Test that room is not yet purged + with self.assertRaises(AssertionError): + self._is_purged(room_id) + + # Wait for next scheduler run + self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS) + + self._is_purged(room_id) + + def test_schedule_shutdown_room(self) -> None: + # Create a test room + room_id = self.helper.create_room_as( + self.other_user, + tok=self.other_user_tok, + ) + + # Schedule a shutdown 10 seconds in the future + delete_id = self.get_success( + self.task_scheduler.schedule_task( + SHUTDOWN_AND_PURGE_ROOM_ACTION_NAME, + resource_id=room_id, + params={ + "requester_user_id": self.admin_user, + "new_room_user_id": self.admin_user, + "new_room_name": None, + "message": None, + "block": False, + "purge": True, + "force_purge": True, + }, + timestamp=self.clock.time_msec() + 10 * 1000, + ) + ) + + # Test that room is not yet shutdown + self._is_member(room_id, self.other_user) + + # Test that room is not yet purged + with self.assertRaises(AssertionError): + self._is_purged(room_id) + + # Wait for next scheduler run + self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS) + + # Test that all users has been kicked (room is shutdown) + self._has_no_members(room_id) + + self._is_purged(room_id) + + # Retrieve delete results + result = self.make_request( + "GET", + self.url_status_by_delete_id + delete_id, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, result.code, msg=result.json_body) + + # Check that the user is in kicked_users + self.assertIn( + self.other_user, result.json_body["shutdown_room"]["kicked_users"] + ) + + new_room_id = result.json_body["shutdown_room"]["new_room_id"] + self.assertTrue(new_room_id) + + # Check that the user is actually in the new room + self._is_member(new_room_id, self.other_user) + def _is_blocked(self, room_id: str, expect: bool = True) -> None: """Assert that the room is blocked or not""" d = self.store.is_room_blocked(room_id) @@ -1034,7 +1167,6 @@ class DeleteRoomV2TestCase(unittest.HomeserverTestCase): kicked_user: a user_id which is kicked from the room expect_new_room: if we expect that a new room was created """ - # get information by room_id channel_room_id = self.make_request( "GET", @@ -1957,11 +2089,8 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase): self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) # Purge every event before the second event. - purge_id = random_string(16) - pagination_handler._purges_by_id[purge_id] = PurgeStatus() self.get_success( - pagination_handler._purge_history( - purge_id=purge_id, + pagination_handler.purge_history( room_id=self.room_id, token=second_token_str, delete_local_events=True, diff --git a/tests/rest/admin/test_server_notice.py b/tests/rest/admin/test_server_notice.py index 28b999573e..dfd14f5751 100644 --- a/tests/rest/admin/test_server_notice.py +++ b/tests/rest/admin/test_server_notice.py @@ -22,6 +22,7 @@ from synapse.server import HomeServer from synapse.storage.roommember import RoomsForUser from synapse.types import JsonDict from synapse.util import Clock +from synapse.util.stringutils import random_string from tests import unittest from tests.unittest import override_config @@ -413,11 +414,24 @@ class ServerNoticeTestCase(unittest.HomeserverTestCase): self.assertEqual(messages[0]["content"]["body"], "test msg one") self.assertEqual(messages[0]["sender"], "@notices:test") + random_string(16) + # shut down and purge room self.get_success( - self.room_shutdown_handler.shutdown_room(first_room_id, self.admin_user) - ) - self.get_success(self.pagination_handler.purge_room(first_room_id)) + self.room_shutdown_handler.shutdown_room( + first_room_id, + { + "requester_user_id": self.admin_user, + "new_room_user_id": None, + "new_room_name": None, + "message": None, + "block": False, + "purge": True, + "force_purge": False, + }, + ) + ) + self.get_success(self.pagination_handler.purge_room(first_room_id, force=False)) # user is not member anymore self._check_invite_and_join_status(self.other_user, 0, 0) diff --git a/tests/rest/admin/test_statistics.py b/tests/rest/admin/test_statistics.py index b60f16b914..cd8ee274d8 100644 --- a/tests/rest/admin/test_statistics.py +++ b/tests/rest/admin/test_statistics.py @@ -12,9 +12,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Dict, List, Optional from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.errors import Codes @@ -34,8 +35,6 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.media_repo = hs.get_media_repository_resource() - self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") @@ -44,6 +43,11 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): self.url = "/_synapse/admin/v1/statistics/users/media" + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + def test_no_auth(self) -> None: """ Try to list users without authentication. @@ -470,12 +474,9 @@ class UserMediaStatisticsTestCase(unittest.HomeserverTestCase): user_token: Access token of the user number_media: Number of media to be created for the user """ - upload_resource = self.media_repo.children[b"upload"] for _ in range(number_media): # Upload some media into the room - self.helper.upload_media( - upload_resource, SMALL_PNG, tok=user_token, expect_code=200 - ) + self.helper.upload_media(SMALL_PNG, tok=user_token, expect_code=200) def _check_fields(self, content: List[JsonDict]) -> None: """Checks that all attributes are present in content diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 434bb56d44..37f37a09d8 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -17,26 +17,36 @@ import hmac import os import urllib.parse from binascii import unhexlify -from typing import List, Optional -from unittest.mock import Mock, patch +from typing import Dict, List, Optional +from unittest.mock import AsyncMock, Mock, patch from parameterized import parameterized, parameterized_class from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType, UserTypes from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.room_versions import RoomVersions from synapse.media.filepath import MediaFilePaths -from synapse.rest.client import devices, login, logout, profile, register, room, sync +from synapse.rest.client import ( + devices, + login, + logout, + profile, + register, + room, + sync, + user_directory, +) from synapse.server import HomeServer +from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests import unittest -from tests.server import FakeSite, make_request -from tests.test_utils import SMALL_PNG, make_awaitable +from tests.test_utils import SMALL_PNG from tests.unittest import override_config @@ -62,8 +72,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs.config.registration.registration_shared_secret = "shared" - self.hs.get_media_repository = Mock() # type: ignore[assignment] - self.hs.get_deactivate_account_handler = Mock() # type: ignore[assignment] + self.hs.get_media_repository = Mock() # type: ignore[method-assign] + self.hs.get_deactivate_account_handler = Mock() # type: ignore[method-assign] return self.hs @@ -410,8 +420,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): store = self.hs.get_datastores().main # Set monthly active users to the limit - store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.server.max_mau_value) + store.get_monthly_active_count = AsyncMock( + return_value=self.hs.config.server.max_mau_value ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -447,6 +457,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, login.register_servlets, + room.register_servlets, ] url = "/_synapse/admin/v2/users" @@ -497,6 +508,62 @@ class UsersListTestCase(unittest.HomeserverTestCase): # Check that all fields are available self._check_fields(channel.json_body["users"]) + def test_last_seen(self) -> None: + """ + Test that last_seen_ts field is properly working. + """ + user1 = self.register_user("u1", "pass") + user1_token = self.login("u1", "pass") + user2 = self.register_user("u2", "pass") + user2_token = self.login("u2", "pass") + user3 = self.register_user("u3", "pass") + user3_token = self.login("u3", "pass") + + self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + self.reactor.advance(10) + self.helper.create_room_as(user2, tok=user2_token) + self.reactor.advance(10) + self.helper.create_room_as(user1, tok=user1_token) + self.reactor.advance(10) + self.helper.create_room_as(user3, tok=user3_token) + self.reactor.advance(10) + + channel = self.make_request( + "GET", + self.url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertEqual(4, len(channel.json_body["users"])) + self.assertEqual(4, channel.json_body["total"]) + + admin_last_seen = channel.json_body["users"][0]["last_seen_ts"] + user1_last_seen = channel.json_body["users"][1]["last_seen_ts"] + user2_last_seen = channel.json_body["users"][2]["last_seen_ts"] + user3_last_seen = channel.json_body["users"][3]["last_seen_ts"] + self.assertTrue(admin_last_seen > 0 and admin_last_seen < 10000) + self.assertTrue(user2_last_seen > 10000 and user2_last_seen < 20000) + self.assertTrue(user1_last_seen > 20000 and user1_last_seen < 30000) + self.assertTrue(user3_last_seen > 30000 and user3_last_seen < 40000) + + self._order_test([self.admin_user, user2, user1, user3], "last_seen_ts") + + self.reactor.advance(LAST_SEEN_GRANULARITY / 1000) + self.helper.create_room_as(user1, tok=user1_token) + self.reactor.advance(10) + + channel = self.make_request( + "GET", + self.url + "/" + user1, + access_token=self.admin_user_tok, + ) + self.assertTrue( + channel.json_body["last_seen_ts"] > 40000 + LAST_SEEN_GRANULARITY + ) + + self._order_test([self.admin_user, user2, user3, user1], "last_seen_ts") + def test_search_term(self) -> None: """Test that searching for a users works correctly""" @@ -870,6 +937,44 @@ class UsersListTestCase(unittest.HomeserverTestCase): self._order_test([self.admin_user, user1, user2], "creation_ts", "f") self._order_test([user2, user1, self.admin_user], "creation_ts", "b") + def test_filter_admins(self) -> None: + """ + Tests whether the various values of the query parameter `admins` lead to the + expected result set. + """ + + # Register an additional non admin user + self.register_user("user", "pass", admin=False) + + # Query all users + channel = self.make_request( + "GET", + f"{self.url}", + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(2, channel.json_body["total"]) + + # Query only admin users + channel = self.make_request( + "GET", + f"{self.url}?admins=true", + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(1, channel.json_body["total"]) + self.assertEqual(1, channel.json_body["users"][0]["admin"]) + + # Query only non admin users + channel = self.make_request( + "GET", + f"{self.url}?admins=false", + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, channel.result) + self.assertEqual(1, channel.json_body["total"]) + self.assertFalse(channel.json_body["users"][0]["admin"]) + @override_config( { "experimental_features": { @@ -933,6 +1038,84 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids) self.assertEqual(not_approved_user, non_admin_user_ids[0]) + def test_filter_not_user_types(self) -> None: + """Tests that the endpoint handles the not_user_types param""" + + regular_user_id = self.register_user("normalo", "secret") + + bot_user_id = self.register_user("robo", "secret") + self.make_request( + "PUT", + "/_synapse/admin/v2/users/" + urllib.parse.quote(bot_user_id), + {"user_type": UserTypes.BOT}, + access_token=self.admin_user_tok, + ) + + support_user_id = self.register_user("foo", "secret") + self.make_request( + "PUT", + "/_synapse/admin/v2/users/" + urllib.parse.quote(support_user_id), + {"user_type": UserTypes.SUPPORT}, + access_token=self.admin_user_tok, + ) + + def test_user_type( + expected_user_ids: List[str], not_user_types: Optional[List[str]] = None + ) -> None: + """Runs a test for the not_user_types param + Args: + expected_user_ids: Ids of the users that are expected to be returned + not_user_types: List of values for the not_user_types param + """ + + user_type_query = "" + + if not_user_types is not None: + user_type_query = "&".join( + [f"not_user_type={u}" for u in not_user_types] + ) + + test_url = f"{self.url}?{user_type_query}" + channel = self.make_request( + "GET", + test_url, + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code) + self.assertEqual(channel.json_body["total"], len(expected_user_ids)) + self.assertEqual( + expected_user_ids, + [u["name"] for u in channel.json_body["users"]], + ) + + # Request without user_types → all users expected + test_user_type([self.admin_user, support_user_id, regular_user_id, bot_user_id]) + + # Request and exclude bot users + test_user_type( + [self.admin_user, support_user_id, regular_user_id], + not_user_types=[UserTypes.BOT], + ) + + # Request and exclude bot and support users + test_user_type( + [self.admin_user, regular_user_id], + not_user_types=[UserTypes.BOT, UserTypes.SUPPORT], + ) + + # Request and exclude empty user types → only expected the bot and support user + test_user_type([support_user_id, bot_user_id], not_user_types=[""]) + + # Request and exclude empty user types and bots → only expected the support user + test_user_type([support_user_id], not_user_types=["", UserTypes.BOT]) + + # Request and exclude a custom type (neither service nor bot) → expect all users + test_user_type( + [self.admin_user, support_user_id, regular_user_id, bot_user_id], + not_user_types=["custom"], + ) + def test_erasure_status(self) -> None: # Create a new user. user_id = self.register_user("eraseme", "eraseme") @@ -963,6 +1146,32 @@ class UsersListTestCase(unittest.HomeserverTestCase): users = {user["name"]: user for user in channel.json_body["users"]} self.assertIs(users[user_id]["erased"], True) + def test_filter_locked(self) -> None: + # Create a new user. + user_id = self.register_user("lockme", "lockme") + + # Lock them + self.get_success(self.store.set_user_locked_status(user_id, True)) + + # Locked user should appear in list users API + channel = self.make_request( + "GET", + self.url + "?locked=true", + access_token=self.admin_user_tok, + ) + users = {user["name"]: user for user in channel.json_body["users"]} + self.assertIn(user_id, users) + self.assertTrue(users[user_id]["locked"]) + + # Locked user should not appear in list users API + channel = self.make_request( + "GET", + self.url + "?locked=false", + access_token=self.admin_user_tok, + ) + users = {user["name"]: user for user in channel.json_body["users"]} + self.assertNotIn(user_id, users) + def _order_test( self, expected_user_list: List[str], @@ -1010,6 +1219,7 @@ class UsersListTestCase(unittest.HomeserverTestCase): self.assertIn("displayname", u) self.assertIn("avatar_url", u) self.assertIn("creation_ts", u) + self.assertIn("last_seen_ts", u) def _create_users(self, number_users: int) -> None: """ @@ -1340,7 +1550,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): # To test deactivation for users without a profile, we delete the profile information for our user. self.get_success( self.store.db_pool.simple_delete_one( - table="profiles", keyvalues={"user_id": "user"} + table="profiles", keyvalues={"full_user_id": "@user:test"} ) ) @@ -1399,6 +1609,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): login.register_servlets, sync.register_servlets, register.register_servlets, + user_directory.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -1708,8 +1919,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): ) # Set monthly active users to the limit - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.server.max_mau_value) + self.store.get_monthly_active_count = AsyncMock( + return_value=self.hs.config.server.max_mau_value ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -1745,8 +1956,8 @@ class UserRestTestCase(unittest.HomeserverTestCase): handler = self.hs.get_registration_handler() # Set monthly active users to the limit - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.server.max_mau_value) + self.store.get_monthly_active_count = AsyncMock( + return_value=self.hs.config.server.max_mau_value ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -2386,6 +2597,105 @@ class UserRestTestCase(unittest.HomeserverTestCase): # This key was removed intentionally. Ensure it is not accidentally re-included. self.assertNotIn("password_hash", channel.json_body) + def test_locked_user(self) -> None: + # User can sync + channel = self.make_request( + "GET", + "/_matrix/client/v3/sync", + access_token=self.other_user_token, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + + # Lock user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"locked": True}, + ) + + # User is not authorized to sync anymore + channel = self.make_request( + "GET", + "/_matrix/client/v3/sync", + access_token=self.other_user_token, + ) + self.assertEqual(401, channel.code, msg=channel.json_body) + self.assertEqual(Codes.USER_LOCKED, channel.json_body["errcode"]) + self.assertTrue(channel.json_body["soft_logout"]) + + @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) + def test_locked_user_not_in_user_dir(self) -> None: + # User is available in the user dir + channel = self.make_request( + "POST", + "/_matrix/client/v3/user_directory/search", + {"search_term": self.other_user}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("results", channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Lock user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"locked": True}, + ) + + # User is not available anymore in the user dir + channel = self.make_request( + "POST", + "/_matrix/client/v3/user_directory/search", + {"search_term": self.other_user}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("results", channel.json_body) + self.assertEqual(0, len(channel.json_body["results"])) + + @override_config( + { + "user_directory": { + "enabled": True, + "search_all_users": True, + "show_locked_users": True, + } + } + ) + def test_locked_user_in_user_dir_with_show_locked_users_option(self) -> None: + # User is available in the user dir + channel = self.make_request( + "POST", + "/_matrix/client/v3/user_directory/search", + {"search_term": self.other_user}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("results", channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + + # Lock user + channel = self.make_request( + "PUT", + self.url_other_user, + access_token=self.admin_user_tok, + content={"locked": True}, + ) + + # User is still available in the user dir + channel = self.make_request( + "POST", + "/_matrix/client/v3/user_directory/search", + {"search_term": self.other_user}, + access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("results", channel.json_body) + self.assertEqual(1, len(channel.json_body["results"])) + @override_config({"user_directory": {"enabled": True, "search_all_users": True}}) def test_change_name_deactivate_user_user_directory(self) -> None: """ @@ -2394,7 +2704,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): """ # is in user directory - profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + profile = self.get_success(self.store._get_user_in_directory(self.other_user)) assert profile is not None self.assertTrue(profile["display_name"] == "User") @@ -2411,7 +2721,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertTrue(channel.json_body["deactivated"]) # is not in user directory - profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + profile = self.get_success(self.store._get_user_in_directory(self.other_user)) self.assertIsNone(profile) # Set new displayname user @@ -2428,7 +2738,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertEqual("Foobar", channel.json_body["displayname"]) # is not in user directory - profile = self.get_success(self.store.get_user_in_directory(self.other_user)) + profile = self.get_success(self.store._get_user_in_directory(self.other_user)) self.assertIsNone(profile) def test_reactivate_user(self) -> None: @@ -2810,6 +3120,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): self.assertIn("consent_version", content) self.assertIn("consent_ts", content) self.assertIn("external_ids", content) + self.assertIn("last_seen_ts", content) # This key was removed intentionally. Ensure it is not accidentally re-included. self.assertNotIn("password_hash", content) @@ -3110,7 +3421,6 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main - self.media_repo = hs.get_media_repository_resource() self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.admin_user = self.register_user("admin", "pass", admin=True) @@ -3121,6 +3431,11 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): self.other_user ) + def create_resource_dict(self) -> Dict[str, Resource]: + resources = super().create_resource_dict() + resources["/_matrix/media"] = self.hs.get_media_repository_resource() + return resources + @parameterized.expand(["GET", "DELETE"]) def test_no_auth(self, method: str) -> None: """Try to list media of an user without authentication.""" @@ -3596,12 +3911,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): Returns: The ID of the newly created media. """ - upload_resource = self.media_repo.children[b"upload"] - download_resource = self.media_repo.children[b"download"] - # Upload some media into the room response = self.helper.upload_media( - upload_resource, image_data, user_token, filename, expect_code=200 + image_data, user_token, filename, expect_code=200 ) # Extract media ID from the response @@ -3609,11 +3921,9 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): media_id = server_and_media_id.split("/")[1] # Try to access a media and to create `last_access_ts` - channel = make_request( - self.reactor, - FakeSite(download_resource, self.reactor), + channel = self.make_request( "GET", - server_and_media_id, + f"/_matrix/media/v3/download/{server_and_media_id}", shorthand=False, access_token=user_token, ) diff --git a/tests/rest/admin/test_username_available.py b/tests/rest/admin/test_username_available.py index 6c04e6c56c..4c69d224b8 100644 --- a/tests/rest/admin/test_username_available.py +++ b/tests/rest/admin/test_username_available.py @@ -50,7 +50,7 @@ class UsernameAvailableTestCase(unittest.HomeserverTestCase): ) handler = self.hs.get_registration_handler() - handler.check_username = check_username # type: ignore[assignment] + handler.check_username = check_username # type: ignore[method-assign] def test_username_available(self) -> None: """ diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index ac19f3c6da..cffbda9a7d 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -31,6 +31,7 @@ from synapse.rest import admin from synapse.rest.client import account, login, register, room from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource from synapse.server import HomeServer +from synapse.storage._base import db_to_json from synapse.types import JsonDict, UserID from synapse.util import Clock @@ -134,6 +135,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the old password self.attempt_wrong_password_login("kermit", old_password) + # Check that the UI Auth information doesn't store the password in the database. + # + # Note that we don't have the UI Auth session ID, so just pull out the single + # row. + ui_auth_data = self.get_success( + self.store.db_pool.simple_select_one( + "ui_auth_sessions", keyvalues={}, retcols=("clientdict",) + ) + ) + client_dict = db_to_json(ui_auth_data["clientdict"]) + self.assertNotIn("new_password", client_dict) + @override_config({"rc_3pid_validation": {"burst_count": 3}}) def test_ratelimit_by_email(self) -> None: """Test that we ratelimit /requestToken for the same email.""" @@ -562,7 +575,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase): # create a bunch of users and add keys for them users = [] - for i in range(0, 20): + for i in range(20): user_id = self.register_user("missPiggy" + str(i), "test") users.append((user_id,)) @@ -1346,7 +1359,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase): return {} # Register a mock that will return the expected result depending on the remote. - self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[assignment] + self.hs.get_federation_http_client().post_json = Mock(side_effect=post_json) # type: ignore[method-assign] # Check that we've got the correct response from the client-side endpoint. self._test_status( diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py index d5b0640e7a..481db9a687 100644 --- a/tests/rest/client/test_account_data.py +++ b/tests/rest/client/test_account_data.py @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest.mock import AsyncMock from synapse.rest import admin from synapse.rest.client import account_data, login, room from tests import unittest -from tests.test_utils import make_awaitable class AccountDataTestCase(unittest.HomeserverTestCase): @@ -32,7 +31,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase): """Tests that the on_account_data_updated module callback is called correctly when a user's account data changes. """ - mocked_callback = Mock(return_value=make_awaitable(None)) + mocked_callback = AsyncMock(return_value=None) self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append( mocked_callback ) diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index c16e8d43f4..cf23430f6a 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -186,3 +186,31 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertGreater(len(details["support"]), 0) for room_version in details["support"]: self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, str(room_version)) + + def test_get_get_token_login_fields_when_disabled(self) -> None: + """By default login via an existing session is disabled.""" + access_token = self.get_success( + self.auth_handler.create_access_token_for_user_id( + self.user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertFalse(capabilities["m.get_login_token"]["enabled"]) + + @override_config({"login_via_existing_session": {"enabled": True}}) + def test_get_get_token_login_fields_when_enabled(self) -> None: + access_token = self.get_success( + self.auth_handler.create_access_token_for_user_id( + self.user, device_id=None, valid_until_ms=None + ) + ) + + channel = self.make_request("GET", self.url, access_token=access_token) + capabilities = channel.json_body["capabilities"] + + self.assertEqual(channel.code, HTTPStatus.OK) + self.assertTrue(capabilities["m.get_login_token"]["enabled"]) diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py index d80eea17d3..60099f8c59 100644 --- a/tests/rest/client/test_devices.py +++ b/tests/rest/client/test_devices.py @@ -13,12 +13,14 @@ # limitations under the License. from http import HTTPStatus +from twisted.internet.defer import ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import NotFoundError from synapse.rest import admin, devices, room, sync -from synapse.rest.client import account, login, register +from synapse.rest.client import account, keys, login, register from synapse.server import HomeServer +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from tests import unittest @@ -208,8 +210,13 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase): login.register_servlets, register.register_servlets, devices.register_servlets, + keys.register_servlets, ] + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.registration = hs.get_registration_handler() + self.message_handler = hs.get_device_message_handler() + def test_PUT(self) -> None: """Sanity-check that we can PUT a dehydrated device. @@ -226,7 +233,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase): "device_data": { "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm", "account": "dehydrated_device", - } + }, + "device_keys": { + "user_id": "@alice:test", + "device_id": "device1", + "valid_until_ts": "80", + "algorithms": [ + "m.olm.curve25519-aes-sha2", + ], + "keys": { + "<algorithm>:<device_id>": "<key_base64>", + }, + "signatures": { + "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"} + }, + }, }, access_token=token, shorthand=False, @@ -234,3 +255,336 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body) device_id = channel.json_body.get("device_id") self.assertIsInstance(device_id, str) + + @unittest.override_config( + {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}} + ) + def test_dehydrate_msc3814(self) -> None: + user = self.register_user("mikey", "pass") + token = self.login(user, "pass", device_id="device1") + content: JsonDict = { + "device_data": { + "algorithm": "m.dehydration.v1.olm", + }, + "device_id": "device1", + "initial_device_display_name": "foo bar", + "device_keys": { + "user_id": "@mikey:test", + "device_id": "device1", + "valid_until_ts": "80", + "algorithms": [ + "m.olm.curve25519-aes-sha2", + ], + "keys": { + "<algorithm>:<device_id>": "<key_base64>", + }, + "signatures": { + "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"} + }, + }, + "fallback_keys": { + "alg1:device1": "f4llb4ckk3y", + "signed_<algorithm>:<device_id>": { + "fallback": "true", + "key": "f4llb4ckk3y", + "signatures": { + "<user_id>": {"<algorithm>:<device_id>": "<key_base64>"} + }, + }, + }, + "one_time_keys": {"alg1:k1": "0net1m3k3y"}, + } + channel = self.make_request( + "PUT", + "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device", + content=content, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + device_id = channel.json_body.get("device_id") + assert device_id is not None + self.assertIsInstance(device_id, str) + self.assertEqual("device1", device_id) + + # test that we can now GET the dehydrated device info + channel = self.make_request( + "GET", + "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device", + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + returned_device_id = channel.json_body.get("device_id") + self.assertEqual(returned_device_id, device_id) + device_data = channel.json_body.get("device_data") + expected_device_data = { + "algorithm": "m.dehydration.v1.olm", + } + self.assertEqual(device_data, expected_device_data) + + # test that the keys are correctly uploaded + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + user: ["device1"], + }, + }, + token, + ) + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body["device_keys"][user][device_id]["keys"], + content["device_keys"]["keys"], + ) + # first claim should return the onetime key we uploaded + res = self.get_success( + self.hs.get_e2e_keys_handler().claim_one_time_keys( + {user: {device_id: {"alg1": 1}}}, + UserID.from_string(user), + timeout=None, + always_include_fallback_keys=False, + ) + ) + self.assertEqual( + res, + { + "failures": {}, + "one_time_keys": {user: {device_id: {"alg1:k1": "0net1m3k3y"}}}, + }, + ) + # second claim should return fallback key + res2 = self.get_success( + self.hs.get_e2e_keys_handler().claim_one_time_keys( + {user: {device_id: {"alg1": 1}}}, + UserID.from_string(user), + timeout=None, + always_include_fallback_keys=False, + ) + ) + self.assertEqual( + res2, + { + "failures": {}, + "one_time_keys": {user: {device_id: {"alg1:device1": "f4llb4ckk3y"}}}, + }, + ) + + # create another device for the user + ( + new_device_id, + _, + _, + _, + ) = self.get_success( + self.registration.register_device( + user_id=user, + device_id=None, + initial_display_name="new device", + ) + ) + requester = create_requester(user, device_id=new_device_id) + + # Send a message to the dehydrated device + ensureDeferred( + self.message_handler.send_device_message( + requester=requester, + message_type="test.message", + messages={user: {device_id: {"body": "test_message"}}}, + ) + ) + self.pump() + + # make sure we can fetch the message with our dehydrated device id + channel = self.make_request( + "POST", + f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events", + content={}, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + expected_content = {"body": "test_message"} + self.assertEqual(channel.json_body["events"][0]["content"], expected_content) + + # fetch messages again and make sure that the message was not deleted + channel = self.make_request( + "POST", + f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events", + content={}, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["events"][0]["content"], expected_content) + next_batch_token = channel.json_body.get("next_batch") + + # make sure fetching messages with next batch token works - there are no unfetched + # messages so we should receive an empty array + content = {"next_batch": next_batch_token} + channel = self.make_request( + "POST", + f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events", + content=content, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["events"], []) + + # make sure we can delete the dehydrated device + channel = self.make_request( + "DELETE", + "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device", + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + + # ...and after deleting it is no longer available + channel = self.make_request( + "GET", + "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device", + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 401) + + @unittest.override_config( + {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}} + ) + def test_msc3814_dehydrated_device_delete_works(self) -> None: + user = self.register_user("mikey", "pass") + token = self.login(user, "pass", device_id="device1") + content: JsonDict = { + "device_data": { + "algorithm": "m.dehydration.v1.olm", + }, + "device_id": "device2", + "initial_device_display_name": "foo bar", + "device_keys": { + "user_id": "@mikey:test", + "device_id": "device2", + "valid_until_ts": "80", + "algorithms": [ + "m.olm.curve25519-aes-sha2", + ], + "keys": { + "<algorithm>:<device_id>": "<key_base64>", + }, + "signatures": { + "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"} + }, + }, + } + channel = self.make_request( + "PUT", + "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device", + content=content, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + device_id = channel.json_body.get("device_id") + assert device_id is not None + self.assertIsInstance(device_id, str) + self.assertEqual("device2", device_id) + + # ensure that keys were uploaded and available + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + user: ["device2"], + }, + }, + token, + ) + self.assertEqual( + channel.json_body["device_keys"][user]["device2"]["keys"], + { + "<algorithm>:<device_id>": "<key_base64>", + }, + ) + + # delete the dehydrated device + channel = self.make_request( + "DELETE", + "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device", + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + + # ensure that keys are no longer available for deleted device + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + user: ["device2"], + }, + }, + token, + ) + self.assertEqual(channel.json_body["device_keys"], {"@mikey:test": {}}) + + # check that an old device is deleted when user PUTs a new device + # First, create a device + content["device_id"] = "device3" + content["device_keys"]["device_id"] = "device3" + channel = self.make_request( + "PUT", + "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device", + content=content, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + device_id = channel.json_body.get("device_id") + assert device_id is not None + self.assertIsInstance(device_id, str) + self.assertEqual("device3", device_id) + + # create a second device without deleting first device + content["device_id"] = "device4" + content["device_keys"]["device_id"] = "device4" + channel = self.make_request( + "PUT", + "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device", + content=content, + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + device_id = channel.json_body.get("device_id") + assert device_id is not None + self.assertIsInstance(device_id, str) + self.assertEqual("device4", device_id) + + # check that the second device that was created is what is returned when we GET + channel = self.make_request( + "GET", + "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device", + access_token=token, + shorthand=False, + ) + self.assertEqual(channel.code, 200) + returned_device_id = channel.json_body["device_id"] + self.assertEqual(returned_device_id, "device4") + + # and that if we query the keys for the first device they are not there + channel = self.make_request( + "POST", + "/_matrix/client/r0/keys/query", + { + "device_keys": { + user: ["device3"], + }, + }, + token, + ) + self.assertEqual(channel.json_body["device_keys"], {"@mikey:test": {}}) diff --git a/tests/rest/client/test_events.py b/tests/rest/client/test_events.py index 54df2a252c..141e0f57a3 100644 --- a/tests/rest/client/test_events.py +++ b/tests/rest/client/test_events.py @@ -45,7 +45,7 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase): hs = self.setup_test_homeserver(config=config) - hs.get_federation_handler = Mock() # type: ignore[assignment] + hs.get_federation_handler = Mock() # type: ignore[method-assign] return hs diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py index 9faa9de050..90a8df147c 100644 --- a/tests/rest/client/test_filter.py +++ b/tests/rest/client/test_filter.py @@ -46,7 +46,9 @@ class FilterTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"filter_id": "0"}) filter = self.get_success( - self.store.get_user_filter(user_localpart="apple", filter_id=0) + self.store.get_user_filter( + user_id=UserID.from_string(FilterTestCase.user_id), filter_id=0 + ) ) self.pump() self.assertEqual(filter, self.EXAMPLE_FILTER) @@ -63,14 +65,14 @@ class FilterTestCase(unittest.HomeserverTestCase): def test_add_filter_non_local_user(self) -> None: _is_mine = self.hs.is_mine - self.hs.is_mine = lambda target_user: False # type: ignore[assignment] + self.hs.is_mine = lambda target_user: False # type: ignore[method-assign] channel = self.make_request( "POST", "/_matrix/client/r0/user/%s/filter" % (self.user_id), self.EXAMPLE_FILTER_JSON, ) - self.hs.is_mine = _is_mine # type: ignore[assignment] + self.hs.is_mine = _is_mine # type: ignore[method-assign] self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index dc32982e22..768d7ad4c2 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -13,11 +13,12 @@ # limitations under the License. import time import urllib.parse -from typing import Any, Dict, List, Optional +from typing import Any, Collection, Dict, List, Optional, Tuple, Union from unittest.mock import Mock from urllib.parse import urlencode import pymacaroons +from typing_extensions import Literal from twisted.test.proto_helpers import MemoryReactor from twisted.web.resource import Resource @@ -26,11 +27,12 @@ import synapse.rest.admin from synapse.api.constants import ApprovalNoticeMedium, LoginType from synapse.api.errors import Codes from synapse.appservice import ApplicationService +from synapse.module_api import ModuleApi from synapse.rest.client import devices, login, logout, register from synapse.rest.client.account import WhoamiRestServlet from synapse.rest.synapse.client import build_synapse_client_resource_tree from synapse.server import HomeServer -from synapse.types import create_requester +from synapse.types import JsonDict, create_requester from synapse.util import Clock from tests import unittest @@ -88,6 +90,56 @@ ADDITIONAL_LOGIN_FLOWS = [ ] +class TestSpamChecker: + def __init__(self, config: None, api: ModuleApi): + api.register_spam_checker_callbacks( + check_login_for_spam=self.check_login_for_spam, + ) + + @staticmethod + def parse_config(config: JsonDict) -> None: + return None + + async def check_login_for_spam( + self, + user_id: str, + device_id: Optional[str], + initial_display_name: Optional[str], + request_info: Collection[Tuple[Optional[str], str]], + auth_provider_id: Optional[str] = None, + ) -> Union[ + Literal["NOT_SPAM"], + Tuple["synapse.module_api.errors.Codes", JsonDict], + ]: + return "NOT_SPAM" + + +class DenyAllSpamChecker: + def __init__(self, config: None, api: ModuleApi): + api.register_spam_checker_callbacks( + check_login_for_spam=self.check_login_for_spam, + ) + + @staticmethod + def parse_config(config: JsonDict) -> None: + return None + + async def check_login_for_spam( + self, + user_id: str, + device_id: Optional[str], + initial_display_name: Optional[str], + request_info: Collection[Tuple[Optional[str], str]], + auth_provider_id: Optional[str] = None, + ) -> Union[ + Literal["NOT_SPAM"], + Tuple["synapse.module_api.errors.Codes", JsonDict], + ]: + # Return an odd set of values to ensure that they get correctly passed + # to the client. + return Codes.LIMIT_EXCEEDED, {"extra": "value"} + + class LoginRestServletTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets_for_client_rest_resource, @@ -117,16 +169,17 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # which sets these values to 10000, but as we're overriding the entire # rc_login dict here, we need to set this manually as well "account": {"per_second": 10000, "burst_count": 10000}, - } + }, + "experimental_features": {"msc4041_enabled": True}, } ) def test_POST_ratelimiting_per_address(self) -> None: # Create different users so we're sure not to be bothered by the per-user # ratelimiter. - for i in range(0, 6): + for i in range(6): self.register_user("kermit" + str(i), "monkey") - for i in range(0, 6): + for i in range(6): params = { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": "kermit" + str(i)}, @@ -137,12 +190,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) + retry_header = channel.headers.getRawHeaders("Retry-After") else: self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. - self.assertTrue(retry_after_ms < 6000) + self.assertLess(retry_after_ms, 6000) + assert retry_header + self.assertLessEqual(int(retry_header[0]), 6) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -165,13 +221,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # which sets these values to 10000, but as we're overriding the entire # rc_login dict here, we need to set this manually as well "address": {"per_second": 10000, "burst_count": 10000}, - } + }, + "experimental_features": {"msc4041_enabled": True}, } ) def test_POST_ratelimiting_per_account(self) -> None: self.register_user("kermit", "monkey") - for i in range(0, 6): + for i in range(6): params = { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": "kermit"}, @@ -182,12 +239,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) + retry_header = channel.headers.getRawHeaders("Retry-After") else: self.assertEqual(channel.code, 200, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. - self.assertTrue(retry_after_ms < 6000) + self.assertLess(retry_after_ms, 6000) + assert retry_header + self.assertLessEqual(int(retry_header[0]), 6) self.reactor.advance(retry_after_ms / 1000.0) @@ -210,13 +270,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): # rc_login dict here, we need to set this manually as well "address": {"per_second": 10000, "burst_count": 10000}, "failed_attempts": {"per_second": 0.17, "burst_count": 5}, - } + }, + "experimental_features": {"msc4041_enabled": True}, } ) def test_POST_ratelimiting_per_account_failed_attempts(self) -> None: self.register_user("kermit", "monkey") - for i in range(0, 6): + for i in range(6): params = { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": "kermit"}, @@ -227,12 +288,15 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): if i == 5: self.assertEqual(channel.code, 429, msg=channel.result) retry_after_ms = int(channel.json_body["retry_after_ms"]) + retry_header = channel.headers.getRawHeaders("Retry-After") else: self.assertEqual(channel.code, 403, msg=channel.result) # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower # than 1min. - self.assertTrue(retry_after_ms < 6000) + self.assertLess(retry_after_ms, 6000) + assert retry_header + self.assertLessEqual(int(retry_header[0]), 6) self.reactor.advance(retry_after_ms / 1000.0 + 1.0) @@ -446,6 +510,82 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): ApprovalNoticeMedium.NONE, channel.json_body["approval_notice_medium"] ) + def test_get_login_flows_with_login_via_existing_disabled(self) -> None: + """GET /login should return m.login.token without get_login_token""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + flows = {flow["type"]: flow for flow in channel.json_body["flows"]} + self.assertNotIn("m.login.token", flows) + + @override_config({"login_via_existing_session": {"enabled": True}}) + def test_get_login_flows_with_login_via_existing_enabled(self) -> None: + """GET /login should return m.login.token with get_login_token true""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + self.assertCountEqual( + channel.json_body["flows"], + [ + {"type": "m.login.token", "get_login_token": True}, + {"type": "m.login.password"}, + {"type": "m.login.application_service"}, + ], + ) + + @override_config( + { + "modules": [ + { + "module": TestSpamChecker.__module__ + + "." + + TestSpamChecker.__qualname__ + } + ] + } + ) + def test_spam_checker_allow(self) -> None: + """Check that that adding a spam checker doesn't break login.""" + self.register_user("kermit", "monkey") + + body = {"type": "m.login.password", "user": "kermit", "password": "monkey"} + + channel = self.make_request( + "POST", + "/_matrix/client/r0/login", + body, + ) + self.assertEqual(channel.code, 200, channel.result) + + @override_config( + { + "modules": [ + { + "module": DenyAllSpamChecker.__module__ + + "." + + DenyAllSpamChecker.__qualname__ + } + ] + } + ) + def test_spam_checker_deny(self) -> None: + """Check that login""" + + self.register_user("kermit", "monkey") + + body = {"type": "m.login.password", "user": "kermit", "password": "monkey"} + + channel = self.make_request( + "POST", + "/_matrix/client/r0/login", + body, + ) + self.assertEqual(channel.code, 403, channel.result) + self.assertLessEqual( + {"errcode": Codes.LIMIT_EXCEEDED, "extra": "value"}.items(), + channel.json_body.items(), + ) + @skip_unless(has_saml2 and HAS_OIDC, "Requires SAML2 and OIDC") class MultiSSOTestCase(unittest.HomeserverTestCase): diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py index b8187db982..f05e619aa8 100644 --- a/tests/rest/client/test_login_token_request.py +++ b/tests/rest/client/test_login_token_request.py @@ -15,14 +15,14 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin -from synapse.rest.client import login, login_token_request +from synapse.rest.client import login, login_token_request, versions from synapse.server import HomeServer from synapse.util import Clock from tests import unittest from tests.unittest import override_config -endpoint = "/_matrix/client/unstable/org.matrix.msc3882/login/token" +GET_TOKEN_ENDPOINT = "/_matrix/client/v1/login/get_token" class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): @@ -30,6 +30,7 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): login.register_servlets, admin.register_servlets, login_token_request.register_servlets, + versions.register_servlets, # TODO: remove once unstable revision 0 support is removed ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: @@ -46,26 +47,26 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): self.password = "password" def test_disabled(self) -> None: - channel = self.make_request("POST", endpoint, {}, access_token=None) + channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=None) self.assertEqual(channel.code, 404) self.register_user(self.user, self.password) token = self.login(self.user, self.password) - channel = self.make_request("POST", endpoint, {}, access_token=token) + channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token) self.assertEqual(channel.code, 404) - @override_config({"experimental_features": {"msc3882_enabled": True}}) + @override_config({"login_via_existing_session": {"enabled": True}}) def test_require_auth(self) -> None: - channel = self.make_request("POST", endpoint, {}, access_token=None) + channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=None) self.assertEqual(channel.code, 401) - @override_config({"experimental_features": {"msc3882_enabled": True}}) + @override_config({"login_via_existing_session": {"enabled": True}}) def test_uia_on(self) -> None: user_id = self.register_user(self.user, self.password) token = self.login(self.user, self.password) - channel = self.make_request("POST", endpoint, {}, access_token=token) + channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token) self.assertEqual(channel.code, 401) self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) @@ -80,9 +81,9 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): }, } - channel = self.make_request("POST", endpoint, uia, access_token=token) + channel = self.make_request("POST", GET_TOKEN_ENDPOINT, uia, access_token=token) self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["expires_in"], 300) + self.assertEqual(channel.json_body["expires_in_ms"], 300000) login_token = channel.json_body["login_token"] @@ -95,15 +96,15 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.json_body["user_id"], user_id) @override_config( - {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}} + {"login_via_existing_session": {"enabled": True, "require_ui_auth": False}} ) def test_uia_off(self) -> None: user_id = self.register_user(self.user, self.password) token = self.login(self.user, self.password) - channel = self.make_request("POST", endpoint, {}, access_token=token) + channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token) self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body["expires_in"], 300) + self.assertEqual(channel.json_body["expires_in_ms"], 300000) login_token = channel.json_body["login_token"] @@ -117,10 +118,10 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): @override_config( { - "experimental_features": { - "msc3882_enabled": True, - "msc3882_ui_auth": False, - "msc3882_token_timeout": "15s", + "login_via_existing_session": { + "enabled": True, + "require_ui_auth": False, + "token_timeout": "15s", } } ) @@ -128,6 +129,40 @@ class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): self.register_user(self.user, self.password) token = self.login(self.user, self.password) - channel = self.make_request("POST", endpoint, {}, access_token=token) + channel = self.make_request("POST", GET_TOKEN_ENDPOINT, {}, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in_ms"], 15000) + + @override_config( + { + "login_via_existing_session": { + "enabled": True, + "require_ui_auth": False, + "token_timeout": "15s", + } + } + ) + def test_unstable_support(self) -> None: + # TODO: remove support for unstable MSC3882 is no longer needed + + # check feature is advertised in versions response: + channel = self.make_request( + "GET", "/_matrix/client/versions", {}, access_token=None + ) + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body["unstable_features"]["org.matrix.msc3882"], True + ) + + self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + # check feature is available via the unstable endpoint and returns an expires_in value in seconds + channel = self.make_request( + "POST", + "/_matrix/client/unstable/org.matrix.msc3882/login/token", + {}, + access_token=token, + ) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body["expires_in"], 15) diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py index 0b8fcb0c47..524ea6047e 100644 --- a/tests/rest/client/test_models.py +++ b/tests/rest/client/test_models.py @@ -12,12 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import unittest as stdlib_unittest +from typing import TYPE_CHECKING -from pydantic import BaseModel, ValidationError from typing_extensions import Literal +from synapse._pydantic_compat import HAS_PYDANTIC_V2 from synapse.rest.client.models import EmailRequestTokenBody +if TYPE_CHECKING or HAS_PYDANTIC_V2: + from pydantic.v1 import BaseModel, ValidationError +else: + from pydantic import BaseModel, ValidationError + class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase): class Model(BaseModel): diff --git a/tests/rest/client/test_notifications.py b/tests/rest/client/test_notifications.py index 700f6587a0..41ceb3db51 100644 --- a/tests/rest/client/test_notifications.py +++ b/tests/rest/client/test_notifications.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -20,7 +20,6 @@ from synapse.rest.client import login, notifications, receipts, room from synapse.server import HomeServer from synapse.util import Clock -from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase @@ -45,7 +44,7 @@ class HTTPPusherTests(HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: # Mock out the calls over federation. fed_transport_client = Mock(spec=["send_transaction"]) - fed_transport_client.send_transaction = simple_async_mock({}) + fed_transport_client.send_transaction = AsyncMock(return_value={}) return self.setup_test_homeserver( federation_transport_client=fed_transport_client, diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index dcbb125a3b..66b387cea3 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from http import HTTPStatus -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -23,7 +23,6 @@ from synapse.types import UserID from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable class PresenceTestCase(unittest.HomeserverTestCase): @@ -36,11 +35,10 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.presence_handler = Mock(spec=PresenceHandler) - self.presence_handler.set_state.return_value = make_awaitable(None) + self.presence_handler.set_state = AsyncMock(return_value=None) hs = self.setup_test_homeserver( "red", - federation_http_client=None, federation_client=Mock(), presence_handler=self.presence_handler, ) diff --git a/tests/rest/client/test_profile.py b/tests/rest/client/test_profile.py index 27c93ad761..ecae092b47 100644 --- a/tests/rest/client/test_profile.py +++ b/tests/rest/client/test_profile.py @@ -68,6 +68,18 @@ class ProfileTestCase(unittest.HomeserverTestCase): res = self._get_displayname() self.assertEqual(res, "test") + def test_set_displayname_with_extra_spaces(self) -> None: + channel = self.make_request( + "PUT", + "/profile/%s/displayname" % (self.owner,), + content={"displayname": " test "}, + access_token=self.owner_tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + res = self._get_displayname() + self.assertEqual(res, "test") + def test_set_displayname_noauth(self) -> None: channel = self.make_request( "PUT", diff --git a/tests/rest/client/test_public_rooms.py b/tests/rest/client/test_public_rooms.py index ec5b19c33f..3de43760db 100644 --- a/tests/rest/client/test_public_rooms.py +++ b/tests/rest/client/test_public_rooms.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple from twisted.test.proto_helpers import MemoryReactor from synapse.rest import admin, login, room from synapse.server import HomeServer -from synapse.types import PublicRoom, ThirdPartyInstanceID from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -87,6 +85,7 @@ class PublicRoomsTestCase(HomeserverTestCase): def test_pagination_limit_1(self) -> None: returned_rooms = set() + channel = None for i in range(10): next_batch = None if i == 0 else channel.json_body["next_batch"] since_query_str = f"&since={next_batch}" if next_batch else "" @@ -115,6 +114,7 @@ class PublicRoomsTestCase(HomeserverTestCase): def test_pagination_limit_2(self) -> None: returned_rooms = set() + channel = None for i in range(5): next_batch = None if i == 0 else channel.json_body["next_batch"] since_query_str = f"&since={next_batch}" if next_batch else "" diff --git a/tests/rest/client/test_push_rule_attrs.py b/tests/rest/client/test_push_rule_attrs.py index 4f875b9289..5aca74475f 100644 --- a/tests/rest/client/test_push_rule_attrs.py +++ b/tests/rest/client/test_push_rule_attrs.py @@ -412,3 +412,70 @@ class PushRuleAttributesTestCase(HomeserverTestCase): ) self.assertEqual(channel.code, 404) self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND) + + def test_contains_user_name(self) -> None: + """ + Tests that `contains_user_name` rule is present and have proper value in `pattern`. + """ + username = "bob" + self.register_user(username, "pass") + token = self.login(username, "pass") + + channel = self.make_request( + "GET", + "/pushrules/global/content/.m.rule.contains_user_name", + access_token=token, + ) + + self.assertEqual(channel.code, 200) + + self.assertEqual( + { + "rule_id": ".m.rule.contains_user_name", + "default": True, + "enabled": True, + "pattern": username, + "actions": [ + "notify", + {"set_tweak": "highlight"}, + {"set_tweak": "sound", "value": "default"}, + ], + }, + channel.json_body, + ) + + def test_is_user_mention(self) -> None: + """ + Tests that `is_user_mention` rule is present and have proper value in `value`. + """ + user = self.register_user("bob", "pass") + token = self.login("bob", "pass") + + channel = self.make_request( + "GET", + "/pushrules/global/override/.m.rule.is_user_mention", + access_token=token, + ) + + self.assertEqual(channel.code, 200) + + self.assertEqual( + { + "rule_id": ".m.rule.is_user_mention", + "default": True, + "enabled": True, + "conditions": [ + { + "kind": "event_property_contains", + "key": "content.m\\.mentions.user_ids", + "value": user, + } + ], + "actions": [ + "notify", + {"set_tweak": "highlight"}, + {"set_tweak": "sound", "value": "default"}, + ], + }, + channel.json_body, + ) diff --git a/tests/rest/client/test_read_marker.py b/tests/rest/client/test_read_marker.py index 0eedcdb476..5cdd5694a0 100644 --- a/tests/rest/client/test_read_marker.py +++ b/tests/rest/client/test_read_marker.py @@ -131,9 +131,6 @@ class ReadMarkerTestCase(unittest.HomeserverTestCase): event = self.get_success(self.store.get_event(event_id_1, allow_none=True)) assert event is None - # TODO See https://github.com/matrix-org/synapse/issues/13476 - self.store.get_event_ordering.invalidate_all() - # Test moving the read marker to a newer event event_id_2 = send_message() channel = self.make_request( diff --git a/tests/rest/client/test_receipts.py b/tests/rest/client/test_receipts.py index 2a7fcea386..ec638c89b7 100644 --- a/tests/rest/client/test_receipts.py +++ b/tests/rest/client/test_receipts.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus +from typing import Optional + from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.rest.client import login, receipts, register +from synapse.api.constants import EduTypes, EventTypes, HistoryVisibility, ReceiptTypes +from synapse.rest.client import login, receipts, room, sync from synapse.server import HomeServer +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -24,30 +29,113 @@ from tests import unittest class ReceiptsTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, - register.register_servlets, receipts.register_servlets, synapse.rest.admin.register_servlets, + room.register_servlets, + sync.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.owner = self.register_user("owner", "pass") - self.owner_tok = self.login("owner", "pass") + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Join the second user + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) def test_send_receipt(self) -> None: + # Send a message. + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + self.assertNotEqual(self._get_read_receipt(), None) + + def test_send_receipt_unknown_event(self) -> None: + """Receipts sent for unknown events are ignored to not break message retention.""" + # Attempt to send a receipt to an unknown room. channel = self.make_request( "POST", "/rooms/!abc:beep/receipt/m.read/$def", content={}, - access_token=self.owner_tok, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertIsNone(self._get_read_receipt()) + + # Attempt to send a receipt to an unknown event. + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/m.read/$def", + content={}, + access_token=self.tok2, ) self.assertEqual(channel.code, 200, channel.result) + self.assertIsNone(self._get_read_receipt()) + + def test_send_receipt_unviewable_event(self) -> None: + """Receipts sent for unviewable events are errors.""" + # Create a room where new users can't see events from before their join + # & send events into it. + room_id = self.helper.create_room_as( + self.user_id, + tok=self.tok, + extra_content={ + "preset": "private_chat", + "initial_state": [ + { + "content": {"history_visibility": HistoryVisibility.JOINED}, + "state_key": "", + "type": EventTypes.RoomHistoryVisibility, + } + ], + }, + ) + res = self.helper.send(room_id, body="hello", tok=self.tok) + + # Attempt to send a receipt from the wrong user. + channel = self.make_request( + "POST", + f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + content={}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 403, channel.result) + + # Join the user to the room, but they still can't see the event. + self.helper.invite(room_id, self.user_id, self.user2, tok=self.tok) + self.helper.join(room=room_id, user=self.user2, tok=self.tok2) + + channel = self.make_request( + "POST", + f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + content={}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 403, channel.result) def test_send_receipt_invalid_room_id(self) -> None: channel = self.make_request( "POST", "/rooms/not-a-room-id/receipt/m.read/$def", content={}, - access_token=self.owner_tok, + access_token=self.tok, ) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -59,7 +147,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): "POST", "/rooms/!abc:beep/receipt/m.read/not-an-event-id", content={}, - access_token=self.owner_tok, + access_token=self.tok, ) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -71,6 +159,123 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): "POST", "/rooms/!abc:beep/receipt/invalid-receipt-type/$def", content={}, - access_token=self.owner_tok, + access_token=self.tok, ) self.assertEqual(channel.code, 400, channel.result) + + def test_private_read_receipts(self) -> None: + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a private read receipt to tell the server the first user's message was read + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that the first user can't see the other user's private read receipt + self.assertIsNone(self._get_read_receipt()) + + def test_public_receipt_can_override_private(self) -> None: + """ + Sending a public read receipt to the same event which has a private read + receipt should cause that receipt to become public. + """ + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a private read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + self.assertIsNone(self._get_read_receipt()) + + # Send a public read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that we did override the private read receipt + self.assertNotEqual(self._get_read_receipt(), None) + + def test_private_receipt_cannot_override_public(self) -> None: + """ + Sending a private read receipt to the same event which has a public read + receipt should cause no change. + """ + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a public read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + self.assertNotEqual(self._get_read_receipt(), None) + + # Send a private read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that we didn't override the public read receipt + self.assertIsNone(self._get_read_receipt()) + + def test_read_receipt_with_empty_body_is_rejected(self) -> None: + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt for this message with an empty body + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}", + access_token=self.tok2, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) + self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body) + + def _get_read_receipt(self) -> Optional[JsonDict]: + """Syncs and returns the read receipt.""" + + # Checks if event is a read receipt + def is_read_receipt(event: JsonDict) -> bool: + return event["type"] == EduTypes.RECEIPT + + # Sync + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] + + if channel.json_body.get("rooms", None) is None: + return None + + # Return the read receipt + ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ + "ephemeral" + ]["events"] + receipt_event = filter(is_read_receipt, ephemeral_events) + return next(receipt_event, None) diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py index 84a60c0b07..4e0a387bd3 100644 --- a/tests/rest/client/test_redactions.py +++ b/tests/rest/client/test_redactions.py @@ -13,13 +13,17 @@ # limitations under the License. from typing import List, Optional +from parameterized import parameterized + from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, RelationTypes -from synapse.api.room_versions import RoomVersions +from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.rest import admin from synapse.rest.client import login, room, sync from synapse.server import HomeServer +from synapse.storage._base import db_to_json +from synapse.storage.database import LoggingTransaction from synapse.types import JsonDict from synapse.util import Clock @@ -217,9 +221,9 @@ class RedactionsTestCase(HomeserverTestCase): self._redact_event(self.mod_access_token, self.room_id, msg_id) @override_config({"experimental_features": {"msc3912_enabled": True}}) - def test_redact_relations(self) -> None: - """Tests that we can redact the relations of an event at the same time as the - event itself. + def test_redact_relations_with_types(self) -> None: + """Tests that we can redact the relations of an event of specific types + at the same time as the event itself. """ # Send a root event. res = self.helper.send_event( @@ -318,6 +322,104 @@ class RedactionsTestCase(HomeserverTestCase): self.assertNotIn("redacted_because", event_dict, event_dict) @override_config({"experimental_features": {"msc3912_enabled": True}}) + def test_redact_all_relations(self) -> None: + """Tests that we can redact all the relations of an event at the same time as the + event itself. + """ + # Send a root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={"msgtype": "m.text", "body": "hello"}, + tok=self.mod_access_token, + ) + root_event_id = res["event_id"] + + # Send an edit to this root event. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "body": " * hello world", + "m.new_content": { + "body": "hello world", + "msgtype": "m.text", + }, + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.REPLACE, + }, + "msgtype": "m.text", + }, + tok=self.mod_access_token, + ) + edit_event_id = res["event_id"] + + # Also send a threaded message whose root is the same as the edit's. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Message, + content={ + "msgtype": "m.text", + "body": "message 1", + "m.relates_to": { + "event_id": root_event_id, + "rel_type": RelationTypes.THREAD, + }, + }, + tok=self.mod_access_token, + ) + threaded_event_id = res["event_id"] + + # Also send a reaction, again with the same root. + res = self.helper.send_event( + room_id=self.room_id, + type=EventTypes.Reaction, + content={ + "m.relates_to": { + "rel_type": RelationTypes.ANNOTATION, + "event_id": root_event_id, + "key": "👍", + } + }, + tok=self.mod_access_token, + ) + reaction_event_id = res["event_id"] + + # Redact the root event, specifying that we also want to delete all events that + # relate to it. + self._redact_event( + self.mod_access_token, + self.room_id, + root_event_id, + with_relations=["*"], + ) + + # Check that the root event got redacted. + event_dict = self.helper.get_event( + self.room_id, root_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the edit got redacted. + event_dict = self.helper.get_event( + self.room_id, edit_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the threaded message got redacted. + event_dict = self.helper.get_event( + self.room_id, threaded_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + # Check that the reaction got redacted. + event_dict = self.helper.get_event( + self.room_id, reaction_event_id, self.mod_access_token + ) + self.assertIn("redacted_because", event_dict, event_dict) + + @override_config({"experimental_features": {"msc3912_enabled": True}}) def test_redact_relations_no_perms(self) -> None: """Tests that, when redacting a message along with its relations, if not all the related messages can be redacted because of insufficient permissions, the @@ -469,35 +571,81 @@ class RedactionsTestCase(HomeserverTestCase): self.assertIn("body", event_dict["content"], event_dict) self.assertEqual("I'm in a thread!", event_dict["content"]["body"]) - def test_content_redaction(self) -> None: - """MSC2174 moved the redacts property to the content.""" + @parameterized.expand( + [ + # Tuples of: + # Room version + # Boolean: True if the redaction event content should include the event ID. + # Boolean: true if the resulting redaction event is expected to include the + # event ID in the content. + (RoomVersions.V10, False, False), + (RoomVersions.V11, True, True), + (RoomVersions.V11, False, True), + ] + ) + def test_redaction_content( + self, room_version: RoomVersion, include_content: bool, expect_content: bool + ) -> None: + """ + Room version 11 moved the redacts property to the content. + + Ensure that the event gets created properly and that the Client-Server + API servers the proper backwards-compatible version. + """ # Create a room with the newer room version. room_id = self.helper.create_room_as( self.mod_user_id, tok=self.mod_access_token, - room_version=RoomVersions.MSC2176.identifier, + room_version=room_version.identifier, ) # Create an event. b = self.helper.send(room_id=room_id, tok=self.mod_access_token) event_id = b["event_id"] - # Attempt to redact it with a bogus event ID. - self._redact_event( + # Ensure the event ID in the URL and the content must match. + if include_content: + self._redact_event( + self.mod_access_token, + room_id, + event_id, + expect_code=400, + content={"redacts": "foo"}, + ) + + # Redact it for real. + result = self._redact_event( self.mod_access_token, room_id, event_id, - expect_code=400, - content={"redacts": "foo"}, + content={"redacts": event_id} if include_content else {}, ) - - # Redact it for real. - self._redact_event(self.mod_access_token, room_id, event_id) + redaction_event_id = result["event_id"] # Sync the room, to get the id of the create event timeline = self._sync_room_timeline(self.mod_access_token, room_id) redact_event = timeline[-1] self.assertEqual(redact_event["type"], EventTypes.Redaction) - # The redacts key should be in the content. - self.assertNotIn("redacts", redact_event) - self.assertEquals(redact_event["content"]["redacts"], event_id) + # The redacts key should be in the content and the redacts keys. + self.assertEqual(redact_event["content"]["redacts"], event_id) + self.assertEqual(redact_event["redacts"], event_id) + + # But it isn't actually part of the event. + def get_event(txn: LoggingTransaction) -> JsonDict: + return db_to_json( + main_datastore._fetch_event_rows(txn, [redaction_event_id])[ + redaction_event_id + ].json + ) + + main_datastore = self.hs.get_datastores().main + event_json = self.get_success( + main_datastore.db_pool.runInteraction("get_event", get_event) + ) + self.assertEqual(event_json["type"], EventTypes.Redaction) + if expect_content: + self.assertNotIn("redacts", event_json) + self.assertEqual(event_json["content"]["redacts"], event_id) + else: + self.assertEqual(event_json["redacts"], event_id) + self.assertNotIn("redacts", event_json["content"]) diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index b228dba861..ba4e017a0e 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -75,7 +75,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, msg=channel.result) det_data = {"user_id": user_id, "home_server": self.hs.hostname} - self.assertDictContainsSubset(det_data, channel.json_body) + self.assertLessEqual(det_data.items(), channel.json_body.items()) def test_POST_appservice_registration_no_type(self) -> None: as_token = "i_am_an_app_service" @@ -136,7 +136,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "device_id": device_id, } self.assertEqual(channel.code, 200, msg=channel.result) - self.assertDictContainsSubset(det_data, channel.json_body) + self.assertLessEqual(det_data.items(), channel.json_body.items()) @override_config({"enable_registration": False}) def test_POST_disabled_registration(self) -> None: @@ -157,7 +157,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"} self.assertEqual(channel.code, 200, msg=channel.result) - self.assertDictContainsSubset(det_data, channel.json_body) + self.assertLessEqual(det_data.items(), channel.json_body.items()) def test_POST_disabled_guest_registration(self) -> None: self.hs.config.registration.allow_guest_access = False @@ -169,7 +169,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting_guest(self) -> None: - for i in range(0, 6): + for i in range(6): url = self.url + b"?kind=guest" channel = self.make_request(b"POST", url, b"{}") @@ -187,7 +187,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}}) def test_POST_ratelimiting(self) -> None: - for i in range(0, 6): + for i in range(6): request_data = { "username": "kermit" + str(i), "password": "monkey", @@ -267,7 +267,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): "device_id": device_id, } self.assertEqual(channel.code, 200, msg=channel.result) - self.assertDictContainsSubset(det_data, channel.json_body) + self.assertLessEqual(det_data.items(), channel.json_body.items()) # Check the `completed` counter has been incremented and pending is 0 res = self.get_success( @@ -1223,7 +1223,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase): def test_GET_ratelimiting(self) -> None: token = "1234" - for i in range(0, 6): + for i in range(6): channel = self.make_request( b"GET", f"{self.url}?token={token}", diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 75439416c1..61773fb28c 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -15,7 +15,7 @@ import urllib.parse from typing import Any, Callable, Dict, List, Optional, Tuple -from unittest.mock import patch +from unittest.mock import AsyncMock, patch from twisted.test.proto_helpers import MemoryReactor @@ -28,7 +28,6 @@ from synapse.util import Clock from tests import unittest from tests.server import FakeChannel -from tests.test_utils import make_awaitable from tests.test_utils.event_injection import inject_event from tests.unittest import override_config @@ -129,7 +128,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) return [ev["event_id"] for ev in channel.json_body["chunk"]] def _get_bundled_aggregations(self) -> JsonDict: @@ -142,7 +141,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): f"/_matrix/client/v3/rooms/{self.room}/event/{self.parent_id}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) return channel.json_body["unsigned"].get("m.relations", {}) def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: @@ -264,7 +263,8 @@ class RelationsTestCase(BaseRelationsTestCase): # Disable the validation to pretend this came over federation. with patch( "synapse.handlers.message.EventCreationHandler._validate_event_relation", - new=lambda self, event: make_awaitable(None), + new_callable=AsyncMock, + return_value=None, ): # Generate a various relations from a different room. self.get_success( @@ -570,7 +570,7 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(200, channel.code, channel.json_body) event_result = channel.json_body - self.assertDictContainsSubset(original_body, event_result["content"]) + self.assertLessEqual(original_body.items(), event_result["content"].items()) # also check /context, which returns the *edited* event channel = self.make_request( @@ -587,14 +587,14 @@ class RelationsTestCase(BaseRelationsTestCase): (context_result, "/context"), ): # The reference metadata should still be intact. - self.assertDictContainsSubset( + self.assertLessEqual( { "m.relates_to": { "event_id": self.parent_id, "rel_type": "m.reference", } - }, - result_event_dict["content"], + }.items(), + result_event_dict["content"].items(), desc, ) @@ -1300,7 +1300,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # not an event the Client-Server API will allow.. with patch( "synapse.handlers.message.EventCreationHandler._validate_event_relation", - new=lambda self, event: make_awaitable(None), + new_callable=AsyncMock, + return_value=None, ): # Create a sub-thread off the thread, which is not allowed. self._send_relation( @@ -1371,9 +1372,11 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): latest_event_in_thread = thread_summary["latest_event"] # The latest event in the thread should have the edit appear under the # bundled aggregations. - self.assertDictContainsSubset( - {"event_id": edit_event_id, "sender": "@alice:test"}, - latest_event_in_thread["unsigned"]["m.relations"][RelationTypes.REPLACE], + self.assertLessEqual( + {"event_id": edit_event_id, "sender": "@alice:test"}.items(), + latest_event_in_thread["unsigned"]["m.relations"][ + RelationTypes.REPLACE + ].items(), ) def test_aggregation_get_event_for_annotation(self) -> None: @@ -1602,7 +1605,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) threads = channel.json_body["chunk"] return [ ( @@ -1634,11 +1637,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase): ################################################## # Check the test data is configured as expected. # ################################################## - self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) + self.assertEqual(self._get_related_events(), list(reversed(thread_replies))) relations = self._get_bundled_aggregations() - self.assertDictContainsSubset( - {"count": 3, "current_user_participated": True}, - relations[RelationTypes.THREAD], + self.assertLessEqual( + {"count": 3, "current_user_participated": True}.items(), + relations[RelationTypes.THREAD].items(), ) # The latest event is the last sent event. self.assertEqual( @@ -1655,11 +1658,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(thread_replies.pop()) # The thread should still exist, but the latest event should be updated. - self.assertEquals(self._get_related_events(), list(reversed(thread_replies))) + self.assertEqual(self._get_related_events(), list(reversed(thread_replies))) relations = self._get_bundled_aggregations() - self.assertDictContainsSubset( - {"count": 2, "current_user_participated": True}, - relations[RelationTypes.THREAD], + self.assertLessEqual( + {"count": 2, "current_user_participated": True}.items(), + relations[RelationTypes.THREAD].items(), ) # And the latest event is the last unredacted event. self.assertEqual( @@ -1674,11 +1677,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(thread_replies.pop(0)) # Nothing should have changed (except the thread count). - self.assertEquals(self._get_related_events(), thread_replies) + self.assertEqual(self._get_related_events(), thread_replies) relations = self._get_bundled_aggregations() - self.assertDictContainsSubset( - {"count": 1, "current_user_participated": True}, - relations[RelationTypes.THREAD], + self.assertLessEqual( + {"count": 1, "current_user_participated": True}.items(), + relations[RelationTypes.THREAD].items(), ) # And the latest event is the last unredacted event. self.assertEqual( @@ -1691,11 +1694,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase): # Redact the last remaining event. # #################################### self._redact(thread_replies.pop(0)) - self.assertEquals(thread_replies, []) + self.assertEqual(thread_replies, []) # The event should no longer be considered a thread. - self.assertEquals(self._get_related_events(), []) - self.assertEquals(self._get_bundled_aggregations(), {}) + self.assertEqual(self._get_related_events(), []) + self.assertEqual(self._get_bundled_aggregations(), {}) self.assertEqual(self._get_threads(), []) def test_redact_parent_edit(self) -> None: @@ -1749,8 +1752,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): # The relations are returned. event_ids = self._get_related_events() relations = self._get_bundled_aggregations() - self.assertEquals(event_ids, [related_event_id]) - self.assertEquals( + self.assertEqual(event_ids, [related_event_id]) + self.assertEqual( relations[RelationTypes.REFERENCE], {"chunk": [{"event_id": related_event_id}]}, ) @@ -1772,13 +1775,13 @@ class RelationRedactionTestCase(BaseRelationsTestCase): # The unredacted relation should still exist. event_ids = self._get_related_events() relations = self._get_bundled_aggregations() - self.assertEquals(len(event_ids), 1) - self.assertDictContainsSubset( + self.assertEqual(len(event_ids), 1) + self.assertLessEqual( { "count": 1, "current_user_participated": True, - }, - relations[RelationTypes.THREAD], + }.items(), + relations[RelationTypes.THREAD].items(), ) self.assertEqual( relations[RelationTypes.THREAD]["latest_event"]["event_id"], @@ -1816,7 +1819,7 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) threads = self._get_threads(channel.json_body) self.assertEqual(threads, [(thread_2, reply_2), (thread_1, reply_1)]) @@ -1829,7 +1832,7 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) # Tuple of (thread ID, latest event ID) for each thread. threads = self._get_threads(channel.json_body) self.assertEqual(threads, [(thread_1, reply_3), (thread_2, reply_2)]) @@ -1850,7 +1853,7 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual(thread_roots, [thread_2]) @@ -1864,7 +1867,7 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads?limit=1&from={next_batch}", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual(thread_roots, [thread_1], channel.json_body) @@ -1899,7 +1902,7 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual( thread_roots, [thread_3, thread_2, thread_1], channel.json_body @@ -1911,7 +1914,7 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads?include=participated", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) @@ -1943,6 +1946,6 @@ class ThreadsTestCase(BaseRelationsTestCase): f"/_matrix/client/v1/rooms/{self.room}/threads", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) + self.assertEqual(200, channel.code, channel.json_body) thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] self.assertEqual(thread_roots, [thread_1], channel.json_body) diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py deleted file mode 100644 index 9d5cb60d16..0000000000 --- a/tests/rest/client/test_room_batch.py +++ /dev/null @@ -1,302 +0,0 @@ -import logging -from typing import List, Tuple -from unittest.mock import Mock, patch - -from twisted.test.proto_helpers import MemoryReactor - -from synapse.api.constants import EventContentFields, EventTypes -from synapse.appservice import ApplicationService -from synapse.rest import admin -from synapse.rest.client import login, register, room, room_batch, sync -from synapse.server import HomeServer -from synapse.types import JsonDict, RoomStreamToken -from synapse.util import Clock - -from tests import unittest - -logger = logging.getLogger(__name__) - - -def _create_join_state_events_for_batch_send_request( - virtual_user_ids: List[str], - insert_time: int, -) -> List[JsonDict]: - return [ - { - "type": EventTypes.Member, - "sender": virtual_user_id, - "origin_server_ts": insert_time, - "content": { - "membership": "join", - "displayname": "display-name-for-%s" % (virtual_user_id,), - }, - "state_key": virtual_user_id, - } - for virtual_user_id in virtual_user_ids - ] - - -def _create_message_events_for_batch_send_request( - virtual_user_id: str, insert_time: int, count: int -) -> List[JsonDict]: - return [ - { - "type": EventTypes.Message, - "sender": virtual_user_id, - "origin_server_ts": insert_time, - "content": { - "msgtype": "m.text", - "body": "Historical %d" % (i), - EventContentFields.MSC2716_HISTORICAL: True, - }, - } - for i in range(count) - ] - - -class RoomBatchTestCase(unittest.HomeserverTestCase): - """Test importing batches of historical messages.""" - - servlets = [ - admin.register_servlets_for_client_rest_resource, - room_batch.register_servlets, - room.register_servlets, - register.register_servlets, - login.register_servlets, - sync.register_servlets, - ] - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - - self.appservice = ApplicationService( - token="i_am_an_app_service", - id="1234", - namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, - # Note: this user does not have to match the regex above - sender="@as_main:test", - ) - - mock_load_appservices = Mock(return_value=[self.appservice]) - with patch( - "synapse.storage.databases.main.appservice.load_appservices", - mock_load_appservices, - ): - hs = self.setup_test_homeserver(config=config) - return hs - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.clock = clock - self._storage_controllers = hs.get_storage_controllers() - - self.virtual_user_id, _ = self.register_appservice_user( - "as_user_potato", self.appservice.token - ) - - def _create_test_room(self) -> Tuple[str, str, str, str]: - room_id = self.helper.create_room_as( - self.appservice.sender, tok=self.appservice.token - ) - - res_a = self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "A", - }, - tok=self.appservice.token, - ) - event_id_a = res_a["event_id"] - - res_b = self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "B", - }, - tok=self.appservice.token, - ) - event_id_b = res_b["event_id"] - - res_c = self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "C", - }, - tok=self.appservice.token, - ) - event_id_c = res_c["event_id"] - - return room_id, event_id_a, event_id_b, event_id_c - - @unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) - def test_same_state_groups_for_whole_historical_batch(self) -> None: - """Make sure that when using the `/batch_send` endpoint to import a - bunch of historical messages, it re-uses the same `state_group` across - the whole batch. This is an easy optimization to make sure we're getting - right because the state for the whole batch is contained in - `state_events_at_start` and can be shared across everything. - """ - - time_before_room = int(self.clock.time_msec()) - room_id, event_id_a, _, _ = self._create_test_room() - - channel = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s" - % (room_id, event_id_a), - content={ - "events": _create_message_events_for_batch_send_request( - self.virtual_user_id, time_before_room, 3 - ), - "state_events_at_start": _create_join_state_events_for_batch_send_request( - [self.virtual_user_id], time_before_room - ), - }, - access_token=self.appservice.token, - ) - self.assertEqual(channel.code, 200, channel.result) - - # Get the historical event IDs that we just imported - historical_event_ids = channel.json_body["event_ids"] - self.assertEqual(len(historical_event_ids), 3) - - # Fetch the state_groups - state_group_map = self.get_success( - self._storage_controllers.state.get_state_groups_ids( - room_id, historical_event_ids - ) - ) - - # We expect all of the historical events to be using the same state_group - # so there should only be a single state_group here! - self.assertEqual( - len(state_group_map.keys()), - 1, - "Expected a single state_group to be returned by saw state_groups=%s" - % (state_group_map.keys(),), - ) - - @unittest.override_config({"experimental_features": {"msc2716_enabled": True}}) - def test_sync_while_batch_importing(self) -> None: - """ - Make sure that /sync correctly returns full room state when a user joins - during ongoing batch backfilling. - See: https://github.com/matrix-org/synapse/issues/12281 - """ - # Create user who will be invited & join room - user_id = self.register_user("beep", "test") - user_tok = self.login("beep", "test") - - time_before_room = int(self.clock.time_msec()) - - # Create a room with some events - room_id, _, _, _ = self._create_test_room() - # Invite the user - self.helper.invite( - room_id, src=self.appservice.sender, tok=self.appservice.token, targ=user_id - ) - - # Create another room, send a bunch of events to advance the stream token - other_room_id = self.helper.create_room_as( - self.appservice.sender, tok=self.appservice.token - ) - for _ in range(5): - self.helper.send_event( - room_id=other_room_id, - type=EventTypes.Message, - content={"msgtype": "m.text", "body": "C"}, - tok=self.appservice.token, - ) - - # Join the room as the normal user - self.helper.join(room_id, user_id, tok=user_tok) - - # Create an event to hang the historical batch from - In order to see - # the failure case originally reported in #12281, the historical batch - # must be hung from the most recent event in the room so the base - # insertion event ends up with the highest `topogological_ordering` - # (`depth`) in the room but will have a negative `stream_ordering` - # because it's a `historical` event. Previously, when assembling the - # `state` for the `/sync` response, the bugged logic would sort by - # `topological_ordering` descending and pick up the base insertion - # event because it has a negative `stream_ordering` below the given - # pagination token. Now we properly sort by `stream_ordering` - # descending which puts `historical` events with a negative - # `stream_ordering` way at the bottom and aren't selected as expected. - response = self.helper.send_event( - room_id=room_id, - type=EventTypes.Message, - content={ - "msgtype": "m.text", - "body": "C", - }, - tok=self.appservice.token, - ) - event_to_hang_id = response["event_id"] - - channel = self.make_request( - "POST", - "/_matrix/client/unstable/org.matrix.msc2716/rooms/%s/batch_send?prev_event_id=%s" - % (room_id, event_to_hang_id), - content={ - "events": _create_message_events_for_batch_send_request( - self.virtual_user_id, time_before_room, 3 - ), - "state_events_at_start": _create_join_state_events_for_batch_send_request( - [self.virtual_user_id], time_before_room - ), - }, - access_token=self.appservice.token, - ) - self.assertEqual(channel.code, 200, channel.result) - - # Now we need to find the invite + join events stream tokens so we can sync between - main_store = self.hs.get_datastores().main - events, next_key = self.get_success( - main_store.get_recent_events_for_room( - room_id, - 50, - end_token=main_store.get_room_max_token(), - ), - ) - invite_event_position = None - for event in events: - if ( - event.type == "m.room.member" - and event.content["membership"] == "invite" - ): - invite_event_position = self.get_success( - main_store.get_topological_token_for_event(event.event_id) - ) - break - - assert invite_event_position is not None, "No invite event found" - - # Remove the topological order from the token by re-creating w/stream only - invite_event_position = RoomStreamToken(None, invite_event_position.stream) - - # Sync everything after this token - since_token = self.get_success(invite_event_position.to_string(main_store)) - sync_response = self.make_request( - "GET", - f"/sync?since={since_token}", - access_token=user_tok, - ) - - # Assert that, for this room, the user was considered to have joined and thus - # receives the full state history - state_event_types = [ - event["type"] - for event in sync_response.json_body["rooms"]["join"][room_id]["state"][ - "events" - ] - ] - - assert ( - "m.room.create" in state_event_types - ), "Missing room full state in sync response" diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 4d39c89f6f..aaa4f3bba0 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -20,7 +20,7 @@ import json from http import HTTPStatus from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -from unittest.mock import Mock, call, patch +from unittest.mock import AsyncMock, Mock, call, patch from urllib import parse as urlparse from parameterized import param, parameterized @@ -41,7 +41,6 @@ from synapse.api.errors import Codes, HttpResponseException from synapse.appservice import ApplicationService from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, register, room, sync from synapse.server import HomeServer @@ -52,7 +51,6 @@ from synapse.util.stringutils import random_string from tests import unittest from tests.http.server._base import make_request_with_cancellation_test from tests.storage.test_stream import PaginationTestCase -from tests.test_utils import make_awaitable from tests.test_utils.event_injection import create_event from tests.unittest import override_config @@ -67,19 +65,17 @@ class RoomBase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.hs = self.setup_test_homeserver( "red", - federation_http_client=None, - federation_client=Mock(), ) - self.hs.get_federation_handler = Mock() # type: ignore[assignment] - self.hs.get_federation_handler.return_value.maybe_backfill = Mock( - return_value=make_awaitable(None) + self.hs.get_federation_handler = Mock() # type: ignore[method-assign] + self.hs.get_federation_handler.return_value.maybe_backfill = AsyncMock( + return_value=None ) async def _insert_client_ip(*args: Any, **kwargs: Any) -> None: return None - self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[assignment] + self.hs.get_datastores().main.insert_client_ip = _insert_client_ip # type: ignore[method-assign] return self.hs @@ -713,7 +709,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(30, channel.resource_usage.db_txn_count) + self.assertEqual(32, channel.resource_usage.db_txn_count) def test_post_room_initial_state(self) -> None: # POST with initial_state config key, expect new room id @@ -726,7 +722,7 @@ class RoomsCreateTestCase(RoomBase): self.assertEqual(HTTPStatus.OK, channel.code, channel.result) self.assertTrue("room_id" in channel.json_body) assert channel.resource_usage is not None - self.assertEqual(32, channel.resource_usage.db_txn_count) + self.assertEqual(34, channel.resource_usage.db_txn_count) def test_post_room_visibility_key(self) -> None: # POST with visibility config key, expect new room id @@ -1364,7 +1360,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): # Ensure the event was persisted with the correct timestamp. res = self.get_success(self.main_store.get_event(event_id)) - self.assertEquals(ts, res.origin_server_ts) + self.assertEqual(ts, res.origin_server_ts) def test_send_state_event_ts(self) -> None: """Test sending a state event with a custom timestamp.""" @@ -1386,7 +1382,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): # Ensure the event was persisted with the correct timestamp. res = self.get_success(self.main_store.get_event(event_id)) - self.assertEquals(ts, res.origin_server_ts) + self.assertEqual(ts, res.origin_server_ts) def test_send_membership_event_ts(self) -> None: """Test sending a membership event with a custom timestamp.""" @@ -1408,7 +1404,7 @@ class RoomAppserviceTsParamTestCase(unittest.HomeserverTestCase): # Ensure the event was persisted with the correct timestamp. res = self.get_success(self.main_store.get_event(event_id)) - self.assertEquals(ts, res.origin_server_ts) + self.assertEqual(ts, res.origin_server_ts) class RoomJoinRatelimitTestCase(RoomBase): @@ -1451,6 +1447,30 @@ class RoomJoinRatelimitTestCase(RoomBase): @unittest.override_config( {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} ) + def test_join_attempts_local_ratelimit(self) -> None: + """Tests that unsuccessful joins that end up being denied are rate-limited.""" + # Create 4 rooms + room_ids = [ + self.helper.create_room_as(self.user_id, is_public=True) for _ in range(4) + ] + # Pre-emptively ban the user who will attempt to join. + joiner_user_id = self.register_user("joiner", "secret") + for room_id in room_ids: + self.helper.ban(room_id, self.user_id, joiner_user_id) + + # Now make a new user try to join some of them. + # The user can make 3 requests, each of which should be denied. + for room_id in room_ids[0:3]: + self.helper.join(room_id, joiner_user_id, expect_code=HTTPStatus.FORBIDDEN) + + # The fourth attempt should be rate limited. + self.helper.join( + room_ids[3], joiner_user_id, expect_code=HTTPStatus.TOO_MANY_REQUESTS + ) + + @unittest.override_config( + {"rc_joins": {"local": {"per_second": 0.5, "burst_count": 3}}} + ) def test_join_local_ratelimit_profile_change(self) -> None: """Tests that sending a profile update into all of the user's joined rooms isn't rate-limited by the rate-limiter on joins.""" @@ -1941,6 +1961,43 @@ class RoomPowerLevelOverridesInPracticeTestCase(RoomBase): channel.json_body["error"], ) + @unittest.override_config( + { + "default_power_level_content_override": { + "private_chat": { + "events": { + "m.room.avatar": 50, + "m.room.canonical_alias": 50, + "m.room.encryption": 999, + "m.room.history_visibility": 100, + "m.room.name": 50, + "m.room.power_levels": 100, + "m.room.server_acl": 100, + "m.room.tombstone": 100, + }, + "events_default": 0, + }, + } + }, + ) + def test_config_override_blocks_encrypted_room(self) -> None: + # Given the server has config for private_chats, + + # When I attempt to create an encrypted private_chat room + channel = self.make_request( + "POST", + "/createRoom", + '{"creation_content": {"m.federate": false},"name": "Secret Private Room","preset": "private_chat","initial_state": [{"type": "m.room.encryption","state_key": "","content": {"algorithm": "m.megolm.v1.aes-sha2"}}]}', + ) + + # Then I am not allowed because the required power level is unattainable + self.assertEqual(HTTPStatus.FORBIDDEN, channel.code, msg=channel.result["body"]) + self.assertEqual( + "You cannot create an encrypted room. " + + "user_level (100) < send_level (999)", + channel.json_body["error"], + ) + class RoomInitialSyncTestCase(RoomBase): """Tests /rooms/$room_id/initialSync.""" @@ -2052,11 +2109,8 @@ class RoomMessageListTestCase(RoomBase): self.assertEqual(len(chunk), 2, [event["content"] for event in chunk]) # Purge every event before the second event. - purge_id = random_string(16) - pagination_handler._purges_by_id[purge_id] = PurgeStatus() self.get_success( - pagination_handler._purge_history( - purge_id=purge_id, + pagination_handler.purge_history( room_id=self.room_id, token=second_token_str, delete_local_events=True, @@ -2340,7 +2394,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): ] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - return self.setup_test_homeserver(federation_client=Mock()) + return self.setup_test_homeserver(federation_client=AsyncMock()) def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.register_user("user", "pass") @@ -2350,7 +2404,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): def test_simple(self) -> None: "Simple test for searching rooms over federation" - self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] + self.federation_client.get_public_rooms.return_value = {} # type: ignore[attr-defined] search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} @@ -2378,7 +2432,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): # with a 404, when using search filters. self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""), - make_awaitable({}), + {}, ) search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"} @@ -3378,17 +3432,17 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] - self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0)) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[method-assign] + self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[method-assign] + return_value=None, ) # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it # allow everything for now. # `spec` argument is needed for this function mock to have `__qualname__`, which # is needed for `Measure` metrics buried in SpamChecker. - mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None) + mock = AsyncMock(return_value=True, spec=lambda *x: None) self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append( mock ) @@ -3416,7 +3470,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Now change the return value of the callback to deny any invite and test that # we can't send the invite. - mock.return_value = make_awaitable(False) + mock.return_value = False channel = self.make_request( method="POST", path="/rooms/" + self.room_id + "/invite", @@ -3442,18 +3496,18 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Mock a few functions to prevent the test from failing due to failing to talk to # a remote IS. We keep the mock for make_and_store_3pid_invite around so we # can check its call_count later on during the test. - make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0))) - self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment] - self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0)) + self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[method-assign] + self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[method-assign] + return_value=None, ) # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it # allow everything for now. # `spec` argument is needed for this function mock to have `__qualname__`, which # is needed for `Measure` metrics buried in SpamChecker. - mock = Mock( - return_value=make_awaitable(synapse.module_api.NOT_SPAM), + mock = AsyncMock( + return_value=synapse.module_api.NOT_SPAM, spec=lambda *x: None, ) self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append( @@ -3484,7 +3538,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): # Now change the return value of the callback to deny any invite and test that # we can't send the invite. We pick an arbitrary error code to be able to check # that the same code has been returned - mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN) + mock.return_value = Codes.CONSENT_NOT_GIVEN channel = self.make_request( method="POST", path="/rooms/" + self.room_id + "/invite", @@ -3503,7 +3557,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase): make_invite_mock.assert_called_once() # Run variant with `Tuple[Codes, dict]`. - mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"})) + mock.return_value = (Codes.EXPIRED_ACCOUNT, {"field": "value"}) channel = self.make_request( method="POST", path="/rooms/" + self.room_id + "/invite", diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py index 8d2cdf8751..9aecf88e41 100644 --- a/tests/rest/client/test_shadow_banned.py +++ b/tests/rest/client/test_shadow_banned.py @@ -84,7 +84,7 @@ class RoomTestCase(_ShadowBannedBase): def test_invite_3pid(self) -> None: """Ensure that a 3PID invite does not attempt to contact the identity server.""" identity_handler = self.hs.get_identity_handler() - identity_handler.lookup_3pid = Mock( # type: ignore[assignment] + identity_handler.lookup_3pid = Mock( # type: ignore[method-assign] side_effect=AssertionError("This should not get called") ) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 9c876c7a32..d60665254e 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from http import HTTPStatus -from typing import List, Optional +from typing import List from parameterized import parameterized @@ -22,7 +21,6 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( - EduTypes, EventContentFields, EventTypes, ReceiptTypes, @@ -376,156 +374,6 @@ class SyncKnockTestCase(KnockingStrippedStateEventHelperMixin): ) -class ReadReceiptsTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - receipts.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - # Register the second user - self.user2 = self.register_user("kermit2", "monkey") - self.tok2 = self.login("kermit2", "monkey") - - # Join the second user - self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - - def test_private_read_receipts(self) -> None: - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a private read receipt to tell the server the first user's message was read - channel = self.make_request( - "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", - {}, - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - - # Test that the first user can't see the other user's private read receipt - self.assertIsNone(self._get_read_receipt()) - - def test_public_receipt_can_override_private(self) -> None: - """ - Sending a public read receipt to the same event which has a private read - receipt should cause that receipt to become public. - """ - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a private read receipt - channel = self.make_request( - "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", - {}, - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - self.assertIsNone(self._get_read_receipt()) - - # Send a public read receipt - channel = self.make_request( - "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", - {}, - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - - # Test that we did override the private read receipt - self.assertNotEqual(self._get_read_receipt(), None) - - def test_private_receipt_cannot_override_public(self) -> None: - """ - Sending a private read receipt to the same event which has a public read - receipt should cause no change. - """ - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a public read receipt - channel = self.make_request( - "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", - {}, - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - self.assertNotEqual(self._get_read_receipt(), None) - - # Send a private read receipt - channel = self.make_request( - "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", - {}, - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - - # Test that we didn't override the public read receipt - self.assertIsNone(self._get_read_receipt()) - - def test_read_receipt_with_empty_body_is_rejected(self) -> None: - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a read receipt for this message with an empty body - channel = self.make_request( - "POST", - f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}", - access_token=self.tok2, - ) - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) - self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body) - - def _get_read_receipt(self) -> Optional[JsonDict]: - """Syncs and returns the read receipt.""" - - # Checks if event is a read receipt - def is_read_receipt(event: JsonDict) -> bool: - return event["type"] == EduTypes.RECEIPT - - # Sync - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] - - if channel.json_body.get("rooms", None) is None: - return None - - # Return the read receipt - ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ - "ephemeral" - ]["events"] - receipt_event = filter(is_read_receipt, ephemeral_events) - return next(receipt_event, None) - - class UnreadMessagesTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index e5ba5a9706..57eb713b15 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -13,7 +13,7 @@ # limitations under the License. import threading from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -33,7 +33,6 @@ from synapse.util import Clock from synapse.util.frozenutils import unfreeze from tests import unittest -from tests.test_utils import make_awaitable if TYPE_CHECKING: from synapse.module_api import ModuleApi @@ -118,7 +117,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): async def _check_event_auth(origin: Any, event: Any, context: Any) -> None: pass - hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment] + hs.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[method-assign] return hs @@ -477,7 +476,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): def test_on_new_event(self) -> None: """Test that the on_new_event callback is called on new events""" - on_new_event = Mock(make_awaitable(None)) + on_new_event = AsyncMock(return_value=None) self.hs.get_module_api_callbacks().third_party_event_rules._on_new_event_callbacks.append( on_new_event ) @@ -580,7 +579,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" # Register a mock callback. - m = Mock(return_value=make_awaitable(None)) + m = AsyncMock(return_value=None) self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( m ) @@ -641,7 +640,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo" # Register a mock callback. - m = Mock(return_value=make_awaitable(None)) + m = AsyncMock(return_value=None) self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( m ) @@ -682,7 +681,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): correctly when processing a user's deactivation. """ # Register a mocked callback. - deactivation_mock = Mock(return_value=make_awaitable(None)) + deactivation_mock = AsyncMock(return_value=None) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._on_user_deactivation_status_changed_callbacks.append( deactivation_mock, @@ -690,7 +689,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): # Also register a mocked callback for profile updates, to check that the # deactivation code calls it in a way that let modules know the user is being # deactivated. - profile_mock = Mock(return_value=make_awaitable(None)) + profile_mock = AsyncMock(return_value=None) self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append( profile_mock, ) @@ -740,7 +739,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): well as a reactivation. """ # Register a mock callback. - m = Mock(return_value=make_awaitable(None)) + m = AsyncMock(return_value=None) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._on_user_deactivation_status_changed_callbacks.append(m) @@ -794,7 +793,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): correctly when processing a user's deactivation. """ # Register a mocked callback. - deactivation_mock = Mock(return_value=make_awaitable(False)) + deactivation_mock = AsyncMock(return_value=False) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._check_can_deactivate_user_callbacks.append( deactivation_mock, @@ -840,7 +839,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): correctly when processing a user's deactivation triggered by a server admin. """ # Register a mocked callback. - deactivation_mock = Mock(return_value=make_awaitable(False)) + deactivation_mock = AsyncMock(return_value=False) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._check_can_deactivate_user_callbacks.append( deactivation_mock, @@ -879,7 +878,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): correctly when processing an admin's shutdown room request. """ # Register a mocked callback. - shutdown_mock = Mock(return_value=make_awaitable(False)) + shutdown_mock = AsyncMock(return_value=False) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._check_can_shutdown_room_callbacks.append( shutdown_mock, @@ -915,7 +914,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): associating a 3PID to an account. """ # Register a mocked callback. - threepid_bind_mock = Mock(return_value=make_awaitable(None)) + threepid_bind_mock = AsyncMock(return_value=None) third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock) @@ -957,11 +956,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): just before associating and removing a 3PID to/from an account. """ # Pretend to be a Synapse module and register both callbacks as mocks. - on_add_user_third_party_identifier_callback_mock = Mock( - return_value=make_awaitable(None) - ) - on_remove_user_third_party_identifier_callback_mock = Mock( - return_value=make_awaitable(None) + on_add_user_third_party_identifier_callback_mock = AsyncMock(return_value=None) + on_remove_user_third_party_identifier_callback_mock = AsyncMock( + return_value=None ) self.hs.get_module_api().register_third_party_rules_callbacks( on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock, @@ -1021,8 +1018,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase): when a user is deactivated and their third-party ID associations are deleted. """ # Pretend to be a Synapse module and register both callbacks as mocks. - on_remove_user_third_party_identifier_callback_mock = Mock( - return_value=make_awaitable(None) + on_remove_user_third_party_identifier_callback_mock = AsyncMock( + return_value=None ) self.hs.get_module_api().register_third_party_rules_callbacks( on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock, diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index d8dc56261a..951a3cbc43 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -14,7 +14,7 @@ from http import HTTPStatus from typing import Any, Generator, Tuple, cast -from unittest.mock import Mock, call +from unittest.mock import AsyncMock, Mock, call from twisted.internet import defer, reactor as _reactor @@ -24,7 +24,6 @@ from synapse.types import ISynapseReactor, JsonDict from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable from tests.utils import MockClock reactor = cast(ISynapseReactor, _reactor) @@ -53,7 +52,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): def test_executes_given_function( self, ) -> Generator["defer.Deferred[Any]", object, None]: - cb = Mock(return_value=make_awaitable(self.mock_http_response)) + cb = AsyncMock(return_value=self.mock_http_response) res = yield self.cache.fetch_or_execute_request( self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg" ) @@ -64,7 +63,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): def test_deduplicates_based_on_key( self, ) -> Generator["defer.Deferred[Any]", object, None]: - cb = Mock(return_value=make_awaitable(self.mock_http_response)) + cb = AsyncMock(return_value=self.mock_http_response) for i in range(3): # invoke multiple times res = yield self.cache.fetch_or_execute_request( self.mock_request, @@ -168,7 +167,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]: - cb = Mock(return_value=make_awaitable(self.mock_http_response)) + cb = AsyncMock(return_value=self.mock_http_response) yield self.cache.fetch_or_execute_request( self.mock_request, self.mock_requester, cb, "an arg" ) diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index 9532e5ddc1..465b696c0b 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -37,7 +37,6 @@ import attr from typing_extensions import Literal from twisted.test.proto_helpers import MemoryReactorClock -from twisted.web.resource import Resource from twisted.web.server import Site from synapse.api.constants import Membership @@ -45,7 +44,7 @@ from synapse.api.errors import Codes from synapse.server import HomeServer from synapse.types import JsonDict -from tests.server import FakeChannel, FakeSite, make_request +from tests.server import FakeChannel, make_request from tests.test_utils.html_parsers import TestHtmlParser from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer @@ -558,7 +557,6 @@ class RestHelper: def upload_media( self, - resource: Resource, image_data: bytes, tok: str, filename: str = "test.png", @@ -576,7 +574,7 @@ class RestHelper: path = "/_matrix/media/r0/upload?filename=%s" % (filename,) channel = make_request( self.reactor, - FakeSite(resource, self.reactor), + self.site, "POST", path, content=image_data, diff --git a/tests/rest/media/test_url_preview.py b/tests/rest/media/test_url_preview.py index 170fb0534a..24459c6af4 100644 --- a/tests/rest/media/test_url_preview.py +++ b/tests/rest/media/test_url_preview.py @@ -24,10 +24,10 @@ from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import IAddress, IResolutionReceiver from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor +from twisted.web.resource import Resource from synapse.config.oembed import OEmbedEndpointConfig from synapse.media.url_previewer import IMAGE_CACHE_EXPIRY_MS -from synapse.rest.media.media_repository_resource import MediaRepositoryResource from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -40,7 +40,7 @@ from tests.test_utils import SMALL_PNG try: import lxml except ImportError: - lxml = None + lxml = None # type: ignore[assignment] class URLPreviewTests(unittest.HomeserverTestCase): @@ -117,8 +117,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.media_repo = hs.get_media_repository() - media_repo_resource = hs.get_media_repository_resource() - self.preview_url = media_repo_resource.children[b"preview_url"] + assert self.media_repo.url_previewer is not None + self.url_previewer = self.media_repo.url_previewer self.lookups: Dict[str, Any] = {} @@ -143,8 +143,15 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.reactor.nameResolver = Resolver() # type: ignore[assignment] - def create_test_resource(self) -> MediaRepositoryResource: - return self.hs.get_media_repository_resource() + def create_resource_dict(self) -> Dict[str, Resource]: + """Create a resource tree for the test server + + A resource tree is a mapping from path to twisted.web.resource. + + The default implementation creates a JsonResource and calls each function in + `servlets` to register servlets against it. + """ + return {"/_matrix/media": self.hs.get_media_repository_resource()} def _assert_small_png(self, json_body: JsonDict) -> None: """Assert properties from the SMALL_PNG test image.""" @@ -159,7 +166,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -183,7 +190,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Check the cache returns the correct response channel = self.make_request( - "GET", "preview_url?url=http://matrix.org", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://matrix.org", + shorthand=False, ) # Check the cache response has the same content @@ -193,13 +202,15 @@ class URLPreviewTests(unittest.HomeserverTestCase): ) # Clear the in-memory cache - self.assertIn("http://matrix.org", self.preview_url._url_previewer._cache) - self.preview_url._url_previewer._cache.pop("http://matrix.org") - self.assertNotIn("http://matrix.org", self.preview_url._url_previewer._cache) + self.assertIn("http://matrix.org", self.url_previewer._cache) + self.url_previewer._cache.pop("http://matrix.org") + self.assertNotIn("http://matrix.org", self.url_previewer._cache) # Check the database cache returns the correct response channel = self.make_request( - "GET", "preview_url?url=http://matrix.org", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://matrix.org", + shorthand=False, ) # Check the cache response has the same content @@ -221,7 +232,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -251,7 +262,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -287,7 +298,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -328,7 +339,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -363,7 +374,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -396,7 +407,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://example.com", + "/_matrix/media/v3/preview_url?url=http://example.com", shorthand=False, await_result=False, ) @@ -425,7 +436,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")] channel = self.make_request( - "GET", "preview_url?url=http://example.com", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://example.com", + shorthand=False, ) # No requests made. @@ -446,7 +459,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")] channel = self.make_request( - "GET", "preview_url?url=http://example.com", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://example.com", + shorthand=False, ) self.assertEqual(channel.code, 502) @@ -463,7 +478,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): Blocked IP addresses, accessed directly, are not spidered. """ channel = self.make_request( - "GET", "preview_url?url=http://192.168.1.1", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://192.168.1.1", + shorthand=False, ) # No requests made. @@ -479,7 +496,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): Blocked IP ranges, accessed directly, are not spidered. """ channel = self.make_request( - "GET", "preview_url?url=http://1.1.1.2", shorthand=False + "GET", "/_matrix/media/v3/preview_url?url=http://1.1.1.2", shorthand=False ) self.assertEqual(channel.code, 403) @@ -497,7 +514,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://example.com", + "/_matrix/media/v3/preview_url?url=http://example.com", shorthand=False, await_result=False, ) @@ -533,7 +550,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): ] channel = self.make_request( - "GET", "preview_url?url=http://example.com", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://example.com", + shorthand=False, ) self.assertEqual(channel.code, 502) self.assertEqual( @@ -553,7 +572,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): ] channel = self.make_request( - "GET", "preview_url?url=http://example.com", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://example.com", + shorthand=False, ) # No requests made. @@ -574,7 +595,9 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.lookups["example.com"] = [(IPv6Address, "2001:800::1")] channel = self.make_request( - "GET", "preview_url?url=http://example.com", shorthand=False + "GET", + "/_matrix/media/v3/preview_url?url=http://example.com", + shorthand=False, ) self.assertEqual(channel.code, 502) @@ -591,10 +614,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): OPTIONS returns the OPTIONS. """ channel = self.make_request( - "OPTIONS", "preview_url?url=http://example.com", shorthand=False + "OPTIONS", + "/_matrix/media/v3/preview_url?url=http://example.com", + shorthand=False, ) - self.assertEqual(channel.code, 200) - self.assertEqual(channel.json_body, {}) + self.assertEqual(channel.code, 204) def test_accept_language_config_option(self) -> None: """ @@ -605,7 +629,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Build and make a request to the server channel = self.make_request( "GET", - "preview_url?url=http://example.com", + "/_matrix/media/v3/preview_url?url=http://example.com", shorthand=False, await_result=False, ) @@ -658,7 +682,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -708,7 +732,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -750,7 +774,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -790,7 +814,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -831,7 +855,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - f"preview_url?{query_params}", + f"/_matrix/media/v3/preview_url?{query_params}", shorthand=False, ) self.pump() @@ -852,7 +876,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://matrix.org", + "/_matrix/media/v3/preview_url?url=http://matrix.org", shorthand=False, await_result=False, ) @@ -889,7 +913,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://twitter.com/matrixdotorg/status/12345", + "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, await_result=False, ) @@ -949,7 +973,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://twitter.com/matrixdotorg/status/12345", + "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, await_result=False, ) @@ -998,7 +1022,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://www.hulu.com/watch/12345", + "/_matrix/media/v3/preview_url?url=http://www.hulu.com/watch/12345", shorthand=False, await_result=False, ) @@ -1043,7 +1067,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://twitter.com/matrixdotorg/status/12345", + "/_matrix/media/v3/preview_url?url=http://twitter.com/matrixdotorg/status/12345", shorthand=False, await_result=False, ) @@ -1072,7 +1096,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", + "/_matrix/media/v3/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", shorthand=False, await_result=False, ) @@ -1164,7 +1188,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", + "/_matrix/media/v3/preview_url?url=http://www.twitter.com/matrixdotorg/status/12345", shorthand=False, await_result=False, ) @@ -1205,7 +1229,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=http://cdn.twitter.com/matrixdotorg", + "/_matrix/media/v3/preview_url?url=http://cdn.twitter.com/matrixdotorg", shorthand=False, await_result=False, ) @@ -1247,7 +1271,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Check fetching channel = self.make_request( "GET", - f"download/{host}/{media_id}", + f"/_matrix/media/v3/download/{host}/{media_id}", shorthand=False, await_result=False, ) @@ -1260,7 +1284,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - f"download/{host}/{media_id}", + f"/_matrix/media/v3/download/{host}/{media_id}", shorthand=False, await_result=False, ) @@ -1295,7 +1319,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): # Check fetching channel = self.make_request( "GET", - f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale", + f"/_matrix/media/v3/thumbnail/{host}/{media_id}?width=32&height=32&method=scale", shorthand=False, await_result=False, ) @@ -1313,7 +1337,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale", + f"/_matrix/media/v3/thumbnail/{host}/{media_id}?width=32&height=32&method=scale", shorthand=False, await_result=False, ) @@ -1343,7 +1367,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertTrue(os.path.isdir(thumbnail_dir)) self.reactor.advance(IMAGE_CACHE_EXPIRY_MS * 1000 + 1) - self.get_success(self.preview_url._url_previewer._expire_url_cache_data()) + self.get_success(self.url_previewer._expire_url_cache_data()) for path in [file_path] + file_dirs + [thumbnail_dir] + thumbnail_dirs: self.assertFalse( @@ -1363,7 +1387,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=" + bad_url, + "/_matrix/media/v3/preview_url?url=" + bad_url, shorthand=False, await_result=False, ) @@ -1372,7 +1396,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=" + good_url, + "/_matrix/media/v3/preview_url?url=" + good_url, shorthand=False, await_result=False, ) @@ -1404,7 +1428,7 @@ class URLPreviewTests(unittest.HomeserverTestCase): channel = self.make_request( "GET", - "preview_url?url=" + bad_url, + "/_matrix/media/v3/preview_url?url=" + bad_url, shorthand=False, await_result=False, ) diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 2091b08d89..377243a170 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -17,6 +17,13 @@ from synapse.rest.well_known import well_known_resource from tests import unittest +try: + import authlib # noqa: F401 + + HAS_AUTHLIB = True +except ImportError: + HAS_AUTHLIB = False + class WellKnownTests(unittest.HomeserverTestCase): def create_test_resource(self) -> Resource: @@ -96,3 +103,37 @@ class WellKnownTests(unittest.HomeserverTestCase): "GET", "/.well-known/matrix/server", shorthand=False ) self.assertEqual(channel.code, 404) + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @unittest.override_config( + { + "public_baseurl": "https://homeserver", # this is only required so that client well known is served + "experimental_features": { + "msc3861": { + "enabled": True, + "issuer": "https://issuer", + "account_management_url": "https://my-account.issuer", + "client_id": "id", + "client_auth_method": "client_secret_post", + "client_secret": "secret", + }, + }, + "disable_registration": True, + } + ) + def test_client_well_known_msc3861_oauth_delegation(self) -> None: + channel = self.make_request( + "GET", "/.well-known/matrix/client", shorthand=False + ) + + self.assertEqual(channel.code, 200) + self.assertEqual( + channel.json_body, + { + "m.homeserver": {"base_url": "https://homeserver/"}, + "org.matrix.msc2965.authentication": { + "issuer": "https://issuer", + "account": "https://my-account.issuer", + }, + }, + ) diff --git a/tests/server.py b/tests/server.py index 7296f0a552..08633fe640 100644 --- a/tests/server.py +++ b/tests/server.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import hashlib +import ipaddress import json import logging import os @@ -26,6 +27,7 @@ from typing import ( Any, Awaitable, Callable, + Deque, Dict, Iterable, List, @@ -41,10 +43,10 @@ from typing import ( from unittest.mock import Mock import attr -from typing_extensions import Deque, ParamSpec +from typing_extensions import ParamSpec from zope.interface import implementer -from twisted.internet import address, threads, udp +from twisted.internet import address, tcp, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed from twisted.internet.error import DNSLookupError @@ -53,6 +55,7 @@ from twisted.internet.interfaces import ( IConnector, IConsumer, IHostnameResolver, + IListeningPort, IProducer, IProtocol, IPullProducer, @@ -62,7 +65,7 @@ from twisted.internet.interfaces import ( IResolverSimple, ITransport, ) -from twisted.internet.protocol import ClientFactory, DatagramProtocol +from twisted.internet.protocol import ClientFactory, DatagramProtocol, Factory from twisted.python import threadpool from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock @@ -523,6 +526,35 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): """ self._tcp_callbacks[(host, port)] = callback + def connectUNIX( + self, + address: str, + factory: ClientFactory, + timeout: float = 30, + checkPID: int = 0, + ) -> IConnector: + """ + Unix sockets aren't supported for unit tests yet. Make it obvious to any + developer trying it out that they will need to do some work before being able + to use it in tests. + """ + raise Exception("Unix sockets are not implemented for tests yet, sorry.") + + def listenUNIX( + self, + address: str, + factory: Factory, + backlog: int = 50, + mode: int = 0o666, + wantPID: int = 0, + ) -> IListeningPort: + """ + Unix sockets aren't supported for unit tests yet. Make it obvious to any + developer trying it out that they will need to do some work before being able + to use it in tests. + """ + raise Exception("Unix sockets are not implemented for tests, sorry") + def connectTCP( self, host: str, @@ -536,6 +568,8 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): conn = super().connectTCP( host, port, factory, timeout=timeout, bindAddress=None ) + if self.lookups and host in self.lookups: + validate_connector(conn, self.lookups[host]) callback = self._tcp_callbacks.get((host, port)) if callback: @@ -568,6 +602,55 @@ class ThreadedMemoryReactorClock(MemoryReactorClock): super().advance(0) +def validate_connector(connector: tcp.Connector, expected_ip: str) -> None: + """Try to validate the obtained connector as it would happen when + synapse is running and the conection will be established. + + This method will raise a useful exception when necessary, else it will + just do nothing. + + This is in order to help catch quirks related to reactor.connectTCP, + since when called directly, the connector's destination will be of type + IPv4Address, with the hostname as the literal host that was given (which + could be an IPv6-only host or an IPv6 literal). + + But when called from reactor.connectTCP *through* e.g. an Endpoint, the + connector's destination will contain the specific IP address with the + correct network stack class. + + Note that testing code paths that use connectTCP directly should not be + affected by this check, unless they specifically add a test with a + matching reactor.lookups[HOSTNAME] = "IPv6Literal", where reactor is of + type ThreadedMemoryReactorClock. + For an example of implementing such tests, see test/handlers/send_email.py. + """ + destination = connector.getDestination() + + # We use address.IPv{4,6}Address to check what the reactor thinks it is + # is sending but check for validity with ipaddress.IPv{4,6}Address + # because they fail with IPs on the wrong network stack. + cls_mapping = { + address.IPv4Address: ipaddress.IPv4Address, + address.IPv6Address: ipaddress.IPv6Address, + } + + cls = cls_mapping.get(destination.__class__) + + if cls is not None: + try: + cls(expected_ip) + except Exception as exc: + raise ValueError( + "Invalid IP type and resolution for %s. Expected %s to be %s" + % (destination, expected_ip, cls.__name__) + ) from exc + else: + raise ValueError( + "Unknown address type %s for %s" + % (destination.__class__.__name__, destination) + ) + + class ThreadPool: """ Threadless thread pool. @@ -639,10 +722,10 @@ def _make_test_homeserver_synchronous(server: HomeServer) -> None: **kwargs, ) - pool.runWithConnection = runWithConnection # type: ignore[assignment] + pool.runWithConnection = runWithConnection # type: ignore[method-assign] pool.runInteraction = runInteraction # type: ignore[assignment] # Replace the thread pool with a threadless 'thread' pool - pool.threadpool = ThreadPool(clock._reactor) # type: ignore[assignment] + pool.threadpool = ThreadPool(clock._reactor) pool.running = True # We've just changed the Databases to run DB transactions on the same @@ -969,8 +1052,6 @@ def setup_test_homeserver( hs.tls_server_context_factory = Mock() hs.setup() - if homeserver_to_use == TestHomeServer: - hs.setup_background_tasks() if isinstance(db_engine, PostgresEngine): database_pool = hs.get_datastores().databases[0] diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index d2bfa53eda..17f428bfc5 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Tuple -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -29,7 +29,6 @@ from synapse.types import JsonDict from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable from tests.unittest import override_config from tests.utils import default_config @@ -69,24 +68,22 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): assert isinstance(rlsn, ResourceLimitsServerNotices) self._rlsn = rlsn - self._rlsn._store.user_last_seen_monthly_active = Mock( - return_value=make_awaitable(1000) - ) - self._rlsn._server_notices_manager.send_notice = Mock( # type: ignore[assignment] - return_value=make_awaitable(Mock()) + self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=1000) + self._rlsn._server_notices_manager.send_notice = AsyncMock( # type: ignore[method-assign] + return_value=Mock() ) self._send_notice = self._rlsn._server_notices_manager.send_notice self.user_id = "@user_id:test" - self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( - return_value=make_awaitable("!something:localhost") + self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = ( + AsyncMock(return_value="!something:localhost") ) - self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = Mock( - return_value=make_awaitable("!something:localhost") + self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = AsyncMock( + return_value="!something:localhost" ) - self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) # type: ignore[assignment] + self._rlsn._store.add_tag_to_room = AsyncMock(return_value=None) # type: ignore[method-assign] + self._rlsn._store.get_tags_for_room = AsyncMock(return_value={}) # type: ignore[method-assign] @override_config({"hs_disabled": True}) def test_maybe_send_server_notice_disabled_hs(self) -> None: @@ -103,14 +100,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self) -> None: """Test when user has blocked notice, but should have it removed""" - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None) + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = Mock( # type: ignore[assignment] - return_value=make_awaitable({"123": mock_event}) + self._rlsn._store.get_events = AsyncMock( # type: ignore[method-assign] + return_value={"123": mock_event} ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) # Would be better to check the content, but once == remove blocking event @@ -125,16 +122,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user has blocked notice, but notice ought to be there (NOOP) """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None, side_effect=ResourceLimitError(403, "foo"), ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = Mock( # type: ignore[assignment] - return_value=make_awaitable({"123": mock_event}) + self._rlsn._store.get_events = AsyncMock( # type: ignore[method-assign] + return_value={"123": mock_event} ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -145,8 +142,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, but should have one """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None, side_effect=ResourceLimitError(403, "foo"), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -158,8 +155,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None) + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -171,12 +168,10 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None) - ) - self._rlsn._store.user_last_seen_monthly_active = Mock( - return_value=make_awaitable(None) + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None ) + self._rlsn._store.user_last_seen_monthly_active = AsyncMock(return_value=None) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) self._send_notice.assert_not_called() @@ -189,8 +184,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test that when server is over MAU limit and alerting is suppressed, then an alert message is not sent into the room """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None, side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), @@ -204,8 +199,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test that when a server is disabled, that MAU limit alerting is ignored. """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None, side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED ), @@ -223,22 +218,22 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): When the room is already in a blocked state, test that when alerting is suppressed that the room is returned to an unblocked state. """ - self._rlsn._auth_blocking.check_auth_blocking = Mock( # type: ignore[assignment] - return_value=make_awaitable(None), + self._rlsn._auth_blocking.check_auth_blocking = AsyncMock( # type: ignore[method-assign] + return_value=None, side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), ) - self._rlsn._is_room_currently_blocked = Mock( # type: ignore[assignment] - return_value=make_awaitable((True, [])) + self._rlsn._is_room_currently_blocked = AsyncMock( # type: ignore[method-assign] + return_value=(True, []) ) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) - self._rlsn._store.get_events = Mock( # type: ignore[assignment] - return_value=make_awaitable({"123": mock_event}) + self._rlsn._store.get_events = AsyncMock( # type: ignore[method-assign] + return_value={"123": mock_event} ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -284,11 +279,9 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): self.user_id = "@user_id:test" def test_server_notice_only_sent_once(self) -> None: - self.store.get_monthly_active_count = Mock(return_value=make_awaitable(1000)) + self.store.get_monthly_active_count = AsyncMock(return_value=1000) - self.store.user_last_seen_monthly_active = Mock( - return_value=make_awaitable(1000) - ) + self.store.user_last_seen_monthly_active = AsyncMock(return_value=1000) # Call the function multiple times to ensure we only send the notice once self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -327,7 +320,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): hasn't been reached (since it's the only user and the limit is 5), so users shouldn't receive a server notice. """ - m = Mock(return_value=make_awaitable(None)) + m = AsyncMock(return_value=None) self._rlsn._server_notices_manager.maybe_get_notice_room_for_user = m user_id = self.register_user("user", "password") diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 2e3f2318d9..6a2f7584f6 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -719,7 +719,10 @@ class AuthChainDifferenceTestCase(unittest.TestCase): persisted_events = {a.event_id: a, b.event_id: b} unpersited_events = {c.event_id: c} - state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}] + state_sets = [ + {("a", ""): a.event_id, ("b", ""): b.event_id}, + {("c", ""): c.event_id}, + ] store = TestStateResolutionStore(persisted_events) @@ -774,8 +777,8 @@ class AuthChainDifferenceTestCase(unittest.TestCase): unpersited_events = {c.event_id: c, d.event_id: d} state_sets = [ - {"a": a.event_id, "b": b.event_id}, - {"c": c.event_id, "d": d.event_id}, + {("a", ""): a.event_id, ("b", ""): b.event_id}, + {("c", ""): c.event_id, ("d", ""): d.event_id}, ] store = TestStateResolutionStore(persisted_events) @@ -841,8 +844,8 @@ class AuthChainDifferenceTestCase(unittest.TestCase): unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e} state_sets = [ - {"a": a.event_id, "b": b.event_id, "e": e.event_id}, - {"c": c.event_id, "d": d.event_id}, + {("a", ""): a.event_id, ("b", ""): b.event_id, ("e", ""): e.event_id}, + {("c", ""): c.event_id, ("d", ""): d.event_id}, ] store = TestStateResolutionStore(persisted_events) diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 9606ecc43b..b223dc750b 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -139,6 +139,55 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # That should result in a single db query to lookup self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + def test_persisting_event_prefills_get_event_cache(self) -> None: + """ + Test to make sure that the `_get_event_cache` is prefilled after we persist an + event and returns the updated value. + """ + event, event_context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + sender=self.user, + type="test_event_type", + content={"body": "conflabulation"}, + ) + ) + + # First, check `_get_event_cache` for the event we just made + # to verify it's not in the cache. + res = self.store._get_event_cache.get_local((event.event_id,)) + self.assertEqual(res, None, "Event was cached when it should not have been.") + + with LoggingContext(name="test") as ctx: + # Persist the event which should invalidate then prefill the + # `_get_event_cache` so we don't return stale values. + # Side Note: Apparently, persisting an event isn't a transaction in the + # sense that it is recorded in the LoggingContext + persistence = self.hs.get_storage_controllers().persistence + assert persistence is not None + self.get_success( + persistence.persist_event( + event, + event_context, + ) + ) + + # Check `_get_event_cache` again and we should see the updated fact + # that we now have the event cached after persisting it. + res = self.store._get_event_cache.get_local((event.event_id,)) + self.assertEqual(res.event, event, "Event not cached as expected.") # type: ignore + + # Try and fetch the event from the database. + self.get_success(self.store.get_event(event.event_id)) + + # Verify that the database hit was avoided. + self.assertEqual( + ctx.get_resource_usage().evt_db_fetch_count, + 0, + "Database was hit, which would not happen if event was cached.", + ) + def test_invalidate_cache_by_room_id(self) -> None: """ Test to make sure that all events associated with the given `(room_id,)` @@ -188,7 +237,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.get_success(self.store._get_event_cache.clear()) + self.store._get_event_cache.clear() def test_simple(self) -> None: """Test that we cache events that we pull from the DB.""" @@ -205,7 +254,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): """ # Reset the event cache - self.get_success(self.store._get_event_cache.clear()) + self.store._get_event_cache.clear() with LoggingContext("test") as ctx: # We keep hold of the event event though we never use it. @@ -215,7 +264,7 @@ class EventCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) # Reset the event cache - self.get_success(self.store._get_event_cache.clear()) + self.store._get_event_cache.clear() with LoggingContext("test") as ctx: self.get_success(self.store.get_event(self.event_id)) @@ -390,7 +439,7 @@ class GetEventCancellationTestCase(unittest.HomeserverTestCase): self.event_id = res["event_id"] # Reset the event cache so the tests start with it empty - self.get_success(self.store._get_event_cache.clear()) + self.store._get_event_cache.clear() @contextmanager def blocking_get_event_calls( diff --git a/tests/storage/databases/main/test_lock.py b/tests/storage/databases/main/test_lock.py index 56cb49d9b5..35f77052a7 100644 --- a/tests/storage/databases/main/test_lock.py +++ b/tests/storage/databases/main/test_lock.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. + from twisted.internet import defer, reactor from twisted.internet.base import ReactorBase from twisted.internet.defer import Deferred from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer -from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS +from synapse.storage.databases.main.lock import _LOCK_TIMEOUT_MS, _RENEWAL_INTERVAL_MS from synapse.util import Clock from tests import unittest @@ -131,6 +132,7 @@ class LockTestCase(unittest.HomeserverTestCase): # We simulate the process getting stuck by cancelling the looping call # that keeps the lock active. + assert lock._looping_call lock._looping_call.stop() # Wait for the lock to timeout. @@ -166,4 +168,338 @@ class LockTestCase(unittest.HomeserverTestCase): # Now call the shutdown code self.get_success(self.store._on_shutdown()) - self.assertEqual(self.store._live_tokens, {}) + self.assertEqual(self.store._live_lock_tokens, {}) + + +class ReadWriteLockTestCase(unittest.HomeserverTestCase): + """Test the read/write lock implementation.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + + def test_acquire_write_contention(self) -> None: + """Test that we can only acquire one write lock at a time""" + # Track the number of tasks holding the lock. + # Should be at most 1. + in_lock = 0 + max_in_lock = 0 + + release_lock: "Deferred[None]" = Deferred() + + async def task() -> None: + nonlocal in_lock + nonlocal max_in_lock + + lock = await self.store.try_acquire_read_write_lock( + "name", "key", write=True + ) + if not lock: + return + + async with lock: + in_lock += 1 + max_in_lock = max(max_in_lock, in_lock) + + # Block to allow other tasks to attempt to take the lock. + await release_lock + + in_lock -= 1 + + # Start 3 tasks. + task1 = defer.ensureDeferred(task()) + task2 = defer.ensureDeferred(task()) + task3 = defer.ensureDeferred(task()) + + # Give the reactor a kick so that the database transaction returns. + self.pump() + + release_lock.callback(None) + + # Run the tasks to completion. + # To work around `Linearizer`s using a different reactor to sleep when + # contended (#12841), we call `runUntilCurrent` on + # `twisted.internet.reactor`, which is a different reactor to that used + # by the homeserver. + assert isinstance(reactor, ReactorBase) + self.get_success(task1) + reactor.runUntilCurrent() + self.get_success(task2) + reactor.runUntilCurrent() + self.get_success(task3) + + # At most one task should have held the lock at a time. + self.assertEqual(max_in_lock, 1) + + def test_acquire_multiple_reads(self) -> None: + """Test that we can acquire multiple read locks at a time""" + # Track the number of tasks holding the lock. + in_lock = 0 + max_in_lock = 0 + + release_lock: "Deferred[None]" = Deferred() + + async def task() -> None: + nonlocal in_lock + nonlocal max_in_lock + + lock = await self.store.try_acquire_read_write_lock( + "name", "key", write=False + ) + if not lock: + return + + async with lock: + in_lock += 1 + max_in_lock = max(max_in_lock, in_lock) + + # Block to allow other tasks to attempt to take the lock. + await release_lock + + in_lock -= 1 + + # Start 3 tasks. + task1 = defer.ensureDeferred(task()) + task2 = defer.ensureDeferred(task()) + task3 = defer.ensureDeferred(task()) + + # Give the reactor a kick so that the database transaction returns. + self.pump() + + release_lock.callback(None) + + # Run the tasks to completion. + # To work around `Linearizer`s using a different reactor to sleep when + # contended (#12841), we call `runUntilCurrent` on + # `twisted.internet.reactor`, which is a different reactor to that used + # by the homeserver. + assert isinstance(reactor, ReactorBase) + self.get_success(task1) + reactor.runUntilCurrent() + self.get_success(task2) + reactor.runUntilCurrent() + self.get_success(task3) + + # At most one task should have held the lock at a time. + self.assertEqual(max_in_lock, 3) + + def test_write_lock_acquired(self) -> None: + """Test that we can take out a write lock and that while we hold it + nobody else can take it out. + """ + # First to acquire this lock, so it should complete + lock = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=True) + ) + assert lock is not None + + # Enter the context manager + self.get_success(lock.__aenter__()) + + # Attempting to acquire the lock again fails, as both read and write. + lock2 = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=True) + ) + self.assertIsNone(lock2) + + lock3 = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=False) + ) + self.assertIsNone(lock3) + + # Calling `is_still_valid` reports true. + self.assertTrue(self.get_success(lock.is_still_valid())) + + # Drop the lock + self.get_success(lock.__aexit__(None, None, None)) + + # We can now acquire the lock again. + lock4 = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=True) + ) + assert lock4 is not None + self.get_success(lock4.__aenter__()) + self.get_success(lock4.__aexit__(None, None, None)) + + def test_read_lock_acquired(self) -> None: + """Test that we can take out a read lock and that while we hold it + only other reads can use it. + """ + # First to acquire this lock, so it should complete + lock = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=False) + ) + assert lock is not None + + # Enter the context manager + self.get_success(lock.__aenter__()) + + # Attempting to acquire the write lock fails + lock2 = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=True) + ) + self.assertIsNone(lock2) + + # Attempting to acquire a read lock succeeds + lock3 = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=False) + ) + assert lock3 is not None + self.get_success(lock3.__aenter__()) + + # Calling `is_still_valid` reports true. + self.assertTrue(self.get_success(lock.is_still_valid())) + + # Drop the first lock + self.get_success(lock.__aexit__(None, None, None)) + + # Attempting to acquire the write lock still fails, as lock3 is still + # active. + lock4 = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=True) + ) + self.assertIsNone(lock4) + + # Drop the still open third lock + self.get_success(lock3.__aexit__(None, None, None)) + + # We can now acquire the lock again. + lock5 = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=True) + ) + assert lock5 is not None + self.get_success(lock5.__aenter__()) + self.get_success(lock5.__aexit__(None, None, None)) + + def test_maintain_lock(self) -> None: + """Test that we don't time out locks while they're still active (lock is + renewed in the background if the process is still alive)""" + + lock = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=True) + ) + assert lock is not None + + self.get_success(lock.__aenter__()) + + # Wait for ages with the lock, we should not be able to get the lock. + for _ in range(10): + self.reactor.advance((_RENEWAL_INTERVAL_MS / 1000)) + + lock2 = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=True) + ) + self.assertIsNone(lock2) + + self.get_success(lock.__aexit__(None, None, None)) + + def test_timeout_lock(self) -> None: + """Test that we time out locks if they're not updated for ages""" + + lock = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=True) + ) + assert lock is not None + + self.get_success(lock.__aenter__()) + + # We simulate the process getting stuck by cancelling the looping call + # that keeps the lock active. + assert lock._looping_call + lock._looping_call.stop() + + # Wait for the lock to timeout. + self.reactor.advance(2 * _LOCK_TIMEOUT_MS / 1000) + + lock2 = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=True) + ) + self.assertIsNotNone(lock2) + + self.assertFalse(self.get_success(lock.is_still_valid())) + + def test_drop(self) -> None: + """Test that dropping the context manager means we stop renewing the lock""" + + lock = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=True) + ) + self.assertIsNotNone(lock) + + del lock + + # Wait for the lock to timeout. + self.reactor.advance(2 * _LOCK_TIMEOUT_MS / 1000) + + lock2 = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=True) + ) + self.assertIsNotNone(lock2) + + def test_shutdown(self) -> None: + """Test that shutting down Synapse releases the locks""" + # Acquire two locks + lock = self.get_success( + self.store.try_acquire_read_write_lock("name", "key", write=True) + ) + self.assertIsNotNone(lock) + lock2 = self.get_success( + self.store.try_acquire_read_write_lock("name", "key2", write=True) + ) + self.assertIsNotNone(lock2) + + # Now call the shutdown code + self.get_success(self.store._on_shutdown()) + + self.assertEqual(self.store._live_read_write_lock_tokens, {}) + + def test_acquire_multiple_locks(self) -> None: + """Tests that acquiring multiple locks at once works.""" + + # Take out multiple locks and ensure that we can't get those locks out + # again. + lock = self.get_success( + self.store.try_acquire_multi_read_write_lock( + [("name1", "key1"), ("name2", "key2")], write=True + ) + ) + self.assertIsNotNone(lock) + + assert lock is not None + self.get_success(lock.__aenter__()) + + lock2 = self.get_success( + self.store.try_acquire_read_write_lock("name1", "key1", write=True) + ) + self.assertIsNone(lock2) + + lock3 = self.get_success( + self.store.try_acquire_read_write_lock("name2", "key2", write=False) + ) + self.assertIsNone(lock3) + + # Overlapping locks attempts will fail, and won't lock any locks. + lock4 = self.get_success( + self.store.try_acquire_multi_read_write_lock( + [("name1", "key1"), ("name3", "key3")], write=True + ) + ) + self.assertIsNone(lock4) + + lock5 = self.get_success( + self.store.try_acquire_read_write_lock("name3", "key3", write=True) + ) + self.assertIsNotNone(lock5) + assert lock5 is not None + self.get_success(lock5.__aenter__()) + self.get_success(lock5.__aexit__(None, None, None)) + + # Once we release the lock we can take out the locks again. + self.get_success(lock.__aexit__(None, None, None)) + + lock6 = self.get_success( + self.store.try_acquire_read_write_lock("name1", "key1", write=True) + ) + self.assertIsNotNone(lock6) + assert lock6 is not None + self.get_success(lock6.__aenter__()) + self.get_success(lock6.__aexit__(None, None, None)) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 5e1324a169..cbce26a725 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -15,7 +15,7 @@ import json import os import tempfile from typing import List, cast -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock import yaml @@ -35,12 +35,11 @@ from synapse.types import DeviceListUpdates from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: - super(ApplicationServiceStoreTestCase, self).setUp() + super().setUp() self.as_yaml_files: List[str] = [] @@ -71,7 +70,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): except Exception: pass - super(ApplicationServiceStoreTestCase, self).tearDown() + super().tearDown() def _add_appservice( self, as_token: str, id: str, url: str, hs_token: str, sender: str @@ -110,7 +109,7 @@ class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase): class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: - super(ApplicationServiceTransactionStoreTestCase, self).setUp() + super().setUp() self.as_yaml_files: List[str] = [] self.hs.config.appservice.app_service_config_files = self.as_yaml_files @@ -339,7 +338,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase): # we aren't testing store._base stuff here, so mock this out # (ignore needed because Mypy won't allow us to assign to a method otherwise) - self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment] + self.store.get_events_as_list = AsyncMock(return_value=events) # type: ignore[method-assign] self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events)) self.get_success(self._insert_txn(service.id, 10, events)) diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py index fd619b64d4..abf7d0564d 100644 --- a/tests/storage/test_background_update.py +++ b/tests/storage/test_background_update.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from unittest.mock import Mock +import logging +from unittest.mock import AsyncMock, Mock import yaml @@ -20,12 +20,18 @@ from twisted.internet.defer import Deferred, ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer -from synapse.storage.background_updates import BackgroundUpdater +from synapse.storage.background_updates import ( + BackgroundUpdater, + ForeignKeyConstraint, + NotNullConstraint, + run_validate_constraint_and_delete_rows_schema_delta, +) +from synapse.storage.database import LoggingTransaction +from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import JsonDict from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable, simple_async_mock from tests.unittest import override_config @@ -324,6 +330,28 @@ class BackgroundUpdateTestCase(unittest.HomeserverTestCase): self.update_handler.side_effect = update_short self.get_success(self.updates.do_next_background_update(False)) + def test_failed_update_logs_exception_details(self) -> None: + needle = "RUH ROH RAGGY" + + def failing_update(progress: JsonDict, count: int) -> int: + raise Exception(needle) + + self.update_handler.side_effect = failing_update + self.update_handler.reset_mock() + + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + values={"update_name": "test_update", "progress_json": "{}"}, + ) + ) + + with self.assertLogs(level=logging.ERROR) as logs: + # Expect a back-to-back RuntimeError to be raised + self.get_failure(self.updates.run_background_updates(False), RuntimeError) + + self.assertTrue(any(needle in log for log in logs.output), logs.output) + class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -341,8 +369,8 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): # Mock out the AsyncContextManager class MockCM: - __aenter__ = simple_async_mock(return_value=None) - __aexit__ = simple_async_mock(return_value=None) + __aenter__ = AsyncMock(return_value=None) + __aexit__ = AsyncMock(return_value=None) self._update_ctx_manager = MockCM @@ -356,9 +384,9 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): # Register the callbacks with more mocks self.hs.get_module_api().register_background_update_controller_callbacks( on_update=self._on_update, - min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)), - default_batch_size=Mock( - return_value=make_awaitable(self._default_batch_size), + min_batch_size=AsyncMock(return_value=self._default_batch_size), + default_batch_size=AsyncMock( + return_value=self._default_batch_size, ), ) @@ -404,3 +432,225 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase): self.pump() self._update_ctx_manager.__aexit__.assert_called() self.get_success(do_update_d) + + +class BackgroundUpdateValidateConstraintTestCase(unittest.HomeserverTestCase): + """Tests the validate contraint and delete background handlers.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.updates: BackgroundUpdater = self.hs.get_datastores().main.db_pool.updates + # the base test class should have run the real bg updates for us + self.assertTrue( + self.get_success(self.updates.has_completed_background_updates()) + ) + + self.store = self.hs.get_datastores().main + + def test_not_null_constraint(self) -> None: + # Create the initial tables, where we have some invalid data. + """Tests adding a not null constraint.""" + table_sql = """ + CREATE TABLE test_constraint( + a INT PRIMARY KEY, + b INT + ); + """ + self.get_success( + self.store.db_pool.execute( + "test_not_null_constraint", lambda _: None, table_sql + ) + ) + + # We add an index so that we can check that its correctly recreated when + # using SQLite. + index_sql = "CREATE INDEX test_index ON test_constraint(a)" + self.get_success( + self.store.db_pool.execute( + "test_not_null_constraint", lambda _: None, index_sql + ) + ) + + self.get_success( + self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1}) + ) + self.get_success( + self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None}) + ) + self.get_success( + self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3}) + ) + + # Now lets do the migration + + table2_sqlite = """ + CREATE TABLE test_constraint2( + a INT PRIMARY KEY, + b INT, + CONSTRAINT test_constraint_name CHECK (b is NOT NULL) + ); + """ + + def delta(txn: LoggingTransaction) -> None: + run_validate_constraint_and_delete_rows_schema_delta( + txn, + ordering=1000, + update_name="test_bg_update", + table="test_constraint", + constraint_name="test_constraint_name", + constraint=NotNullConstraint("b"), + sqlite_table_name="test_constraint2", + sqlite_table_schema=table2_sqlite, + ) + + self.get_success( + self.store.db_pool.runInteraction( + "test_not_null_constraint", + delta, + ) + ) + + if isinstance(self.store.database_engine, PostgresEngine): + # Postgres uses a background update + self.updates.register_background_validate_constraint_and_delete_rows( + "test_bg_update", + table="test_constraint", + constraint_name="test_constraint_name", + constraint=NotNullConstraint("b"), + unique_columns=["a"], + ) + + # Tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + self.wait_for_background_updates() + + # Check the correct values are in the new table. + rows = self.get_success( + self.store.db_pool.simple_select_list( + table="test_constraint", + keyvalues={}, + retcols=("a", "b"), + ) + ) + + self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) + + # And check that invalid rows get correctly rejected. + self.get_failure( + self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": None}), + exc=self.store.database_engine.module.IntegrityError, + ) + + # Check the index is still there for SQLite. + if isinstance(self.store.database_engine, Sqlite3Engine): + # Ensure the index exists in the schema. + self.get_success( + self.store.db_pool.simple_select_one_onecol( + table="sqlite_master", + keyvalues={"tbl_name": "test_constraint"}, + retcol="name", + ) + ) + + def test_foreign_constraint(self) -> None: + """Tests adding a not foreign key constraint.""" + + # Create the initial tables, where we have some invalid data. + base_sql = """ + CREATE TABLE base_table( + b INT PRIMARY KEY + ); + """ + + table_sql = """ + CREATE TABLE test_constraint( + a INT PRIMARY KEY, + b INT NOT NULL + ); + """ + self.get_success( + self.store.db_pool.execute( + "test_foreign_key_constraint", lambda _: None, base_sql + ) + ) + self.get_success( + self.store.db_pool.execute( + "test_foreign_key_constraint", lambda _: None, table_sql + ) + ) + + self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 1})) + self.get_success( + self.store.db_pool.simple_insert("test_constraint", {"a": 1, "b": 1}) + ) + self.get_success( + self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2}) + ) + self.get_success(self.store.db_pool.simple_insert("base_table", {"b": 3})) + self.get_success( + self.store.db_pool.simple_insert("test_constraint", {"a": 3, "b": 3}) + ) + + table2_sqlite = """ + CREATE TABLE test_constraint2( + a INT PRIMARY KEY, + b INT NOT NULL, + CONSTRAINT test_constraint_name FOREIGN KEY (b) REFERENCES base_table (b) + ); + """ + + def delta(txn: LoggingTransaction) -> None: + run_validate_constraint_and_delete_rows_schema_delta( + txn, + ordering=1000, + update_name="test_bg_update", + table="test_constraint", + constraint_name="test_constraint_name", + constraint=ForeignKeyConstraint( + "base_table", [("b", "b")], deferred=False + ), + sqlite_table_name="test_constraint2", + sqlite_table_schema=table2_sqlite, + ) + + self.get_success( + self.store.db_pool.runInteraction( + "test_foreign_key_constraint", + delta, + ) + ) + + if isinstance(self.store.database_engine, PostgresEngine): + # Postgres uses a background update + self.updates.register_background_validate_constraint_and_delete_rows( + "test_bg_update", + table="test_constraint", + constraint_name="test_constraint_name", + constraint=ForeignKeyConstraint( + "base_table", [("b", "b")], deferred=False + ), + unique_columns=["a"], + ) + + # Tell the DataStore that it hasn't finished all updates yet + self.store.db_pool.updates._all_done = False + + # Now let's actually drive the updates to completion + self.wait_for_background_updates() + + # Check the correct values are in the new table. + rows = self.get_success( + self.store.db_pool.simple_select_list( + table="test_constraint", + keyvalues={}, + retcols=("a", "b"), + ) + ) + self.assertCountEqual(rows, [{"a": 1, "b": 1}, {"a": 3, "b": 3}]) + + # And check that invalid rows get correctly rejected. + self.get_failure( + self.store.db_pool.simple_insert("test_constraint", {"a": 2, "b": 2}), + exc=self.store.database_engine.module.IntegrityError, + ) diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index 7de109966d..ceb9597dd3 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -120,7 +120,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(latest_event_ids, [event_id_4]) + self.assertEqual(latest_event_ids, {event_id_4}) def test_basic_cleanup(self) -> None: """Test that extremities are correctly calculated in the presence of @@ -147,7 +147,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) + self.assertEqual(latest_event_ids, {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -155,7 +155,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(latest_event_ids, [event_id_b]) + self.assertEqual(latest_event_ids, {event_id_b}) def test_chain_of_fail_cleanup(self) -> None: """Test that extremities are correctly calculated in the presence of @@ -185,7 +185,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b}) + self.assertEqual(latest_event_ids, {event_id_a, event_id_b}) # Run the background update and check it did the right thing self.run_background_update() @@ -193,7 +193,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(latest_event_ids, [event_id_b]) + self.assertEqual(latest_event_ids, {event_id_b}) def test_forked_graph_cleanup(self) -> None: r"""Test that extremities are correctly calculated in the presence of @@ -240,7 +240,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), {event_id_a, event_id_b, event_id_c}) + self.assertEqual(latest_event_ids, {event_id_a, event_id_b, event_id_c}) # Run the background update and check it did the right thing self.run_background_update() @@ -248,7 +248,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase): latest_event_ids = self.get_success( self.store.get_latest_event_ids_in_room(self.room_id) ) - self.assertEqual(set(latest_event_ids), {event_id_b, event_id_c}) + self.assertEqual(latest_event_ids, {event_id_b, event_id_c}) class CleanupExtremDummyEventsTestCase(HomeserverTestCase): diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py index cd0079871c..6b9692c486 100644 --- a/tests/storage/test_client_ips.py +++ b/tests/storage/test_client_ips.py @@ -14,7 +14,7 @@ # limitations under the License. from typing import Any, Dict -from unittest.mock import Mock +from unittest.mock import AsyncMock from parameterized import parameterized @@ -30,7 +30,6 @@ from synapse.util import Clock from tests import unittest from tests.server import make_request -from tests.test_utils import make_awaitable from tests.unittest import override_config @@ -66,15 +65,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) r = result[(user_id, device_id)] - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user_id, "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 12345678000, - }, - r, + }.items(), + r.items(), ) def test_insert_new_client_ip_none_device_id(self) -> None: @@ -443,9 +442,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): lots_of_users = 100 user_id = "@user:server" - self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(lots_of_users) - ) + self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users) self.get_success( self.store.insert_client_ip( user_id, "access_token", "ip", "user_agent", "device_id" @@ -529,15 +526,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) r = result[(user_id, device_id)] - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user_id, "device_id": device_id, "ip": None, "user_agent": None, "last_seen": None, - }, - r, + }.items(), + r.items(), ) # Register the background update to run again. @@ -564,15 +561,15 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) r = result[(user_id, device_id)] - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user_id, "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, - }, - r, + }.items(), + r.items(), ) def test_old_user_ips_pruned(self) -> None: @@ -643,15 +640,80 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase): ) r = result2[(user_id, device_id)] - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": user_id, "device_id": device_id, "ip": "ip", "user_agent": "user_agent", "last_seen": 0, - }, - r, + }.items(), + r.items(), + ) + + def test_invalid_user_agents_are_ignored(self) -> None: + # First make sure we have completed all updates. + self.wait_for_background_updates() + + user_id1 = "@user1:id" + user_id2 = "@user2:id" + device_id1 = "MY_DEVICE1" + device_id2 = "MY_DEVICE2" + access_token1 = "access_token1" + access_token2 = "access_token2" + + # Insert a user IP 1 + self.get_success( + self.store.store_device( + user_id1, + device_id1, + "display name1", + ) + ) + # Insert a user IP 2 + self.get_success( + self.store.store_device( + user_id2, + device_id2, + "display name2", + ) + ) + + self.get_success( + self.store.insert_client_ip( + user_id1, access_token1, "ip", "sync-v3-proxy-", device_id1 + ) + ) + self.get_success( + self.store.insert_client_ip( + user_id2, access_token2, "ip", "user_agent", device_id2 + ) + ) + # Force persisting to disk + self.reactor.advance(200) + + # We should see that in the DB + result = self.get_success( + self.store.db_pool.simple_select_list( + table="user_ips", + keyvalues={}, + retcols=["access_token", "ip", "user_agent", "device_id", "last_seen"], + desc="get_user_ip_and_agents", + ) + ) + + # ensure user1 is filtered out + self.assertEqual( + result, + [ + { + "access_token": access_token2, + "ip": "ip", + "user_agent": "user_agent", + "device_id": device_id2, + "last_seen": 0, + } + ], ) @@ -715,13 +777,13 @@ class ClientIpAuthTestCase(unittest.HomeserverTestCase): self.store.get_last_client_ip_by_device(self.user_id, device_id) ) r = result[(self.user_id, device_id)] - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": self.user_id, "device_id": device_id, "ip": expected_ip, "user_agent": "Mozzila pizza", "last_seen": 123456100, - }, - r, + }.items(), + r.items(), ) diff --git a/tests/storage/test_devices.py b/tests/storage/test_devices.py index f03807c8f9..58ab41cf26 100644 --- a/tests/storage/test_devices.py +++ b/tests/storage/test_devices.py @@ -58,13 +58,13 @@ class DeviceStoreTestCase(HomeserverTestCase): res = self.get_success(self.store.get_device("user_id", "device_id")) assert res is not None - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": "user_id", "device_id": "device_id", "display_name": "display_name", - }, - res, + }.items(), + res.items(), ) def test_get_devices_by_user(self) -> None: @@ -80,21 +80,21 @@ class DeviceStoreTestCase(HomeserverTestCase): res = self.get_success(self.store.get_devices_by_user("user_id")) self.assertEqual(2, len(res.keys())) - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": "user_id", "device_id": "device1", "display_name": "display_name 1", - }, - res["device1"], + }.items(), + res["device1"].items(), ) - self.assertDictContainsSubset( + self.assertLessEqual( { "user_id": "user_id", "device_id": "device2", "display_name": "display_name 2", - }, - res["device2"], + }.items(), + res["device2"].items(), ) def test_count_devices_by_users(self) -> None: diff --git a/tests/storage/test_e2e_room_keys.py b/tests/storage/test_e2e_room_keys.py index 9cb326d90a..f6df31aba4 100644 --- a/tests/storage/test_e2e_room_keys.py +++ b/tests/storage/test_e2e_room_keys.py @@ -31,7 +31,7 @@ room_key: RoomKey = { class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - hs = self.setup_test_homeserver("server", federation_http_client=None) + hs = self.setup_test_homeserver("server") self.store = hs.get_datastores().main return hs diff --git a/tests/storage/test_end_to_end_keys.py b/tests/storage/test_end_to_end_keys.py index 5fde3b9c78..2033377b52 100644 --- a/tests/storage/test_end_to_end_keys.py +++ b/tests/storage/test_end_to_end_keys.py @@ -38,7 +38,7 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase): self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] - self.assertDictContainsSubset(json, dev) + self.assertLessEqual(json.items(), dev.items()) def test_reupload_key(self) -> None: now = 1470174257070 @@ -71,8 +71,12 @@ class EndToEndKeyStoreTestCase(HomeserverTestCase): self.assertIn("user", res) self.assertIn("device", res["user"]) dev = res["user"]["device"] - self.assertDictContainsSubset( - {"key": "value", "unsigned": {"device_display_name": "display_name"}}, dev + self.assertLessEqual( + { + "key": "value", + "unsigned": {"device_display_name": "display_name"}, + }.items(), + dev.items(), ) def test_multiple_devices(self) -> None: diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py index e39b63edac..b55dd07f14 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py @@ -401,7 +401,10 @@ class EventChainStoreTestCase(HomeserverTestCase): assert persist_events_store is not None persist_events_store._store_event_txn( txn, - [(e, EventContext(self.hs.get_storage_controllers())) for e in events], + [ + (e, EventContext(self.hs.get_storage_controllers(), {})) + for e in events + ], ) # Actually call the function that calculates the auth chain stuff. @@ -661,7 +664,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): # Add a bunch of state so that it takes multiple iterations of the # background update to process the room. - for i in range(0, 150): + for i in range(150): self.helper.send_state( room_id, event_type="m.test", body={"index": i}, tok=self.token ) @@ -715,12 +718,12 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): # Add a bunch of state so that it takes multiple iterations of the # background update to process the room. - for i in range(0, 150): + for i in range(150): self.helper.send_state( room_id1, event_type="m.test", body={"index": i}, tok=self.token ) - for i in range(0, 150): + for i in range(150): self.helper.send_state( room_id2, event_type="m.test", body={"index": i}, tok=self.token ) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 81e50bdd55..d3e20f44b2 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -13,7 +13,19 @@ # limitations under the License. import datetime -from typing import Dict, List, Tuple, Union, cast +from typing import ( + Collection, + Dict, + FrozenSet, + Iterable, + List, + Mapping, + Set, + Tuple, + TypeVar, + Union, + cast, +) import attr from parameterized import parameterized @@ -38,6 +50,138 @@ from synapse.util import Clock, json_encoder import tests.unittest import tests.utils +# The silly auth graph we use to test the auth difference algorithm, +# where the top are the most recent events. +# +# A B +# \ / +# D E +# \ | +# ` F C +# | /| +# G ´ | +# | \ | +# H I +# | | +# K J + +AUTH_GRAPH: Dict[str, List[str]] = { + "a": ["e"], + "b": ["e"], + "c": ["g", "i"], + "d": ["f"], + "e": ["f"], + "f": ["g"], + "g": ["h", "i"], + "h": ["k"], + "i": ["j"], + "k": [], + "j": [], +} + +DEPTH_GRAPH = { + "a": 7, + "b": 7, + "c": 4, + "d": 6, + "e": 6, + "f": 5, + "g": 3, + "h": 2, + "i": 2, + "k": 1, + "j": 1, +} + +T = TypeVar("T") + + +def get_all_topologically_sorted_orders( + nodes: Iterable[T], + graph: Mapping[T, Collection[T]], +) -> List[List[T]]: + """Given a set of nodes and a graph, return all possible topological + orderings. + """ + + # This is implemented by Kahn's algorithm, and forking execution each time + # we have a choice over which node to consider next. + + degree_map = {node: 0 for node in nodes} + reverse_graph: Dict[T, Set[T]] = {} + + for node, edges in graph.items(): + if node not in degree_map: + continue + + for edge in set(edges): + if edge in degree_map: + degree_map[node] += 1 + + reverse_graph.setdefault(edge, set()).add(node) + reverse_graph.setdefault(node, set()) + + zero_degree = [node for node, degree in degree_map.items() if degree == 0] + + return _get_all_topologically_sorted_orders_inner( + reverse_graph, zero_degree, degree_map + ) + + +def _get_all_topologically_sorted_orders_inner( + reverse_graph: Dict[T, Set[T]], + zero_degree: List[T], + degree_map: Dict[T, int], +) -> List[List[T]]: + new_paths = [] + + # Rather than only choosing *one* item from the list of nodes with zero + # degree, we "fork" execution and run the algorithm for each node in the + # zero degree. + for node in zero_degree: + new_degree_map = degree_map.copy() + new_zero_degree = zero_degree.copy() + new_zero_degree.remove(node) + + for edge in reverse_graph.get(node, []): + if edge in new_degree_map: + new_degree_map[edge] -= 1 + if new_degree_map[edge] == 0: + new_zero_degree.append(edge) + + paths = _get_all_topologically_sorted_orders_inner( + reverse_graph, new_zero_degree, new_degree_map + ) + for path in paths: + path.insert(0, node) + + new_paths.extend(paths) + + if not new_paths: + return [[]] + + return new_paths + + +def get_all_topologically_consistent_subsets( + nodes: Iterable[T], + graph: Mapping[T, Collection[T]], +) -> Set[FrozenSet[T]]: + """Get all subsets of the graph where if node N is in the subgraph, then all + nodes that can reach that node (i.e. for all X there exists a path X -> N) + are in the subgraph. + """ + all_topological_orderings = get_all_topologically_sorted_orders(nodes, graph) + + graph_subsets = set() + for ordering in all_topological_orderings: + ordering.reverse() + + for idx in range(len(ordering)): + graph_subsets.add(frozenset(ordering[:idx])) + + return graph_subsets + @attr.s(auto_attribs=True, frozen=True, slots=True) class _BackfillSetupInfo: @@ -83,7 +227,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): (room_id, event_id), ) - for i in range(0, 20): + for i in range(20): self.get_success( self.store.db_pool.runInteraction("insert", insert_event, i) ) @@ -91,7 +235,7 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # this should get the last ten r = self.get_success(self.store.get_prev_events_for_room(room_id)) self.assertEqual(10, len(r)) - for i in range(0, 10): + for i in range(10): self.assertEqual("$event_%i:local" % (19 - i), r[i]) def test_get_rooms_with_many_extremities(self) -> None: @@ -99,8 +243,32 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): room2 = "#room2" room3 = "#room3" - def insert_event(txn: Cursor, i: int, room_id: str) -> None: + def insert_event(txn: LoggingTransaction, i: int, room_id: str) -> None: event_id = "$event_%i:local" % i + + # We need to insert into events table to get around the foreign key constraint. + self.store.db_pool.simple_insert_txn( + txn, + table="events", + values={ + "instance_name": "master", + "stream_ordering": self.store._stream_id_gen.get_next_txn(txn), + "topological_ordering": 1, + "depth": 1, + "event_id": event_id, + "room_id": room_id, + "type": EventTypes.Message, + "processed": True, + "outlier": False, + "origin_server_ts": 0, + "received_ts": 0, + "sender": "@user:local", + "contains_url": False, + "state_key": None, + "rejection_reason": None, + }, + ) + txn.execute( ( "INSERT INTO event_forward_extremities (room_id, event_id) " @@ -109,15 +277,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): (room_id, event_id), ) - for i in range(0, 20): + for i in range(20): self.get_success( self.store.db_pool.runInteraction("insert", insert_event, i, room1) ) self.get_success( - self.store.db_pool.runInteraction("insert", insert_event, i, room2) + self.store.db_pool.runInteraction( + "insert", insert_event, i + 100, room2 + ) ) self.get_success( - self.store.db_pool.runInteraction("insert", insert_event, i, room3) + self.store.db_pool.runInteraction( + "insert", insert_event, i + 200, room3 + ) ) # Test simple case @@ -144,49 +316,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def _setup_auth_chain(self, use_chain_cover_index: bool) -> str: room_id = "@ROOM:local" - # The silly auth graph we use to test the auth difference algorithm, - # where the top are the most recent events. - # - # A B - # \ / - # D E - # \ | - # ` F C - # | /| - # G ´ | - # | \ | - # H I - # | | - # K J - - auth_graph: Dict[str, List[str]] = { - "a": ["e"], - "b": ["e"], - "c": ["g", "i"], - "d": ["f"], - "e": ["f"], - "f": ["g"], - "g": ["h", "i"], - "h": ["k"], - "i": ["j"], - "k": [], - "j": [], - } - - depth_map = { - "a": 7, - "b": 7, - "c": 4, - "d": 6, - "e": 6, - "f": 5, - "g": 3, - "h": 2, - "i": 2, - "k": 1, - "j": 1, - } - # Mark the room as maybe having a cover index. def store_room(txn: LoggingTransaction) -> None: @@ -210,9 +339,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): def insert_event(txn: LoggingTransaction) -> None: stream_ordering = 0 - for event_id in auth_graph: + for event_id in AUTH_GRAPH: stream_ordering += 1 - depth = depth_map[event_id] + depth = DEPTH_GRAPH[event_id] self.store.db_pool.simple_insert_txn( txn, @@ -232,8 +361,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): self.persist_events._persist_event_auth_chain_txn( txn, [ - cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) - for event_id in auth_graph + cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id])) + for event_id in AUTH_GRAPH ], ) @@ -316,7 +445,51 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): room_id = self._setup_auth_chain(use_chain_cover_index) # Now actually test that various combinations give the right result: + self.assert_auth_diff_is_expected(room_id) + + @parameterized.expand( + [ + [graph_subset] + for graph_subset in get_all_topologically_consistent_subsets( + AUTH_GRAPH, AUTH_GRAPH + ) + ] + ) + def test_auth_difference_partial(self, graph_subset: Collection[str]) -> None: + """Test that if we only have a chain cover index on a partial subset of + the room we still get the correct auth chain difference. + + We do this by removing the chain cover index for every valid subset of the + graph. + """ + room_id = self._setup_auth_chain(True) + + for event_id in graph_subset: + # Remove chain cover from that event. + self.get_success( + self.store.db_pool.simple_delete( + table="event_auth_chains", + keyvalues={"event_id": event_id}, + desc="test_auth_difference_partial_remove", + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + table="event_auth_chain_to_calculate", + values={ + "event_id": event_id, + "room_id": room_id, + "type": "", + "state_key": "", + }, + desc="test_auth_difference_partial_remove", + ) + ) + + self.assert_auth_diff_is_expected(room_id) + def assert_auth_diff_is_expected(self, room_id: str) -> None: + """Assert the auth chain difference returns the correct answers.""" difference = self.get_success( self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}]) ) @@ -924,215 +1097,42 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"]) - def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo: + def test_get_event_ids_with_failed_pull_attempts(self) -> None: """ - Sets up a room with various insertion event backward extremities to test - backfill functions against. - - Returns: - _BackfillSetupInfo including the `room_id` to test against and - `depth_map` of events in the room + Test to make sure we properly get event_ids based on whether they have any + failed pull attempts. """ - room_id = "!backfill-room-test:some-host" - - depth_map: Dict[str, int] = { - "1": 1, - "2": 2, - "insertion_eventA": 3, - "3": 4, - "insertion_eventB": 5, - "4": 6, - "5": 7, - } - - def populate_db(txn: LoggingTransaction) -> None: - # Insert the room to satisfy the foreign key constraint of - # `event_failed_pull_attempts` - self.store.db_pool.simple_insert_txn( - txn, - "rooms", - { - "room_id": room_id, - "creator": "room_creator_user_id", - "is_public": True, - "room_version": "6", - }, - ) - - # Insert our server events - stream_ordering = 0 - for event_id, depth in depth_map.items(): - self.store.db_pool.simple_insert_txn( - txn, - table="events", - values={ - "event_id": event_id, - "type": EventTypes.MSC2716_INSERTION - if event_id.startswith("insertion_event") - else "test_regular_type", - "room_id": room_id, - "depth": depth, - "topological_ordering": depth, - "stream_ordering": stream_ordering, - "processed": True, - "outlier": False, - }, - ) - - if event_id.startswith("insertion_event"): - self.store.db_pool.simple_insert_txn( - txn, - table="insertion_event_extremities", - values={ - "event_id": event_id, - "room_id": room_id, - }, - ) - - stream_ordering += 1 - - self.get_success( - self.store.db_pool.runInteraction( - "_setup_room_for_insertion_backfill_tests_populate_db", - populate_db, - ) - ) - - return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map) - - def test_get_insertion_event_backward_extremities_in_room(self) -> None: - """ - Test to make sure only insertion event backward extremities that are - older and come before the `current_depth` are returned. - """ - setup_info = self._setup_room_for_insertion_backfill_tests() - room_id = setup_info.room_id - depth_map = setup_info.depth_map - - # Try at "insertion_eventB" - backfill_points = self.get_success( - self.store.get_insertion_event_backward_extremities_in_room( - room_id, depth_map["insertion_eventB"], limit=100 - ) - ) - backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertEqual(backfill_event_ids, ["insertion_eventB", "insertion_eventA"]) - - # Try at "insertion_eventA" - backfill_points = self.get_success( - self.store.get_insertion_event_backward_extremities_in_room( - room_id, depth_map["insertion_eventA"], limit=100 - ) - ) - backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - # Event "2" has a depth of 2 but is not included here because we only - # know the approximate depth of 5 from our event "3". - self.assertListEqual(backfill_event_ids, ["insertion_eventA"]) - - def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted( - self, - ) -> None: - """ - Test to make sure that insertion events we have attempted to backfill - (and within backoff timeout duration) do not show up as an event to - backfill again. - """ - setup_info = self._setup_room_for_insertion_backfill_tests() - room_id = setup_info.room_id - depth_map = setup_info.depth_map - - # Record some attempts to backfill these events which will make - # `get_insertion_event_backward_extremities_in_room` exclude them - # because we haven't passed the backoff interval. - self.get_success( - self.store.record_event_failed_pull_attempt( - room_id, "insertion_eventA", "fake cause" - ) - ) - - # No time has passed since we attempted to backfill ^ - - # Try at "insertion_eventB" - backfill_points = self.get_success( - self.store.get_insertion_event_backward_extremities_in_room( - room_id, depth_map["insertion_eventB"], limit=100 - ) - ) - backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - # Only the backfill points that we didn't record earlier exist here. - self.assertEqual(backfill_event_ids, ["insertion_eventB"]) - - def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration( - self, - ) -> None: - """ - Test to make sure after we fake attempt to backfill event - "insertion_eventA" many times, we can see retry and see the - "insertion_eventA" again after the backoff timeout duration has - exceeded. - """ - setup_info = self._setup_room_for_insertion_backfill_tests() - room_id = setup_info.room_id - depth_map = setup_info.depth_map + # Create the room + user_id = self.register_user("alice", "test") + tok = self.login("alice", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) - # Record some attempts to backfill these events which will make - # `get_backfill_points_in_room` exclude them because we - # haven't passed the backoff interval. - self.get_success( - self.store.record_event_failed_pull_attempt( - room_id, "insertion_eventB", "fake cause" - ) - ) - self.get_success( - self.store.record_event_failed_pull_attempt( - room_id, "insertion_eventA", "fake cause" - ) - ) - self.get_success( - self.store.record_event_failed_pull_attempt( - room_id, "insertion_eventA", "fake cause" - ) - ) self.get_success( self.store.record_event_failed_pull_attempt( - room_id, "insertion_eventA", "fake cause" + room_id, "$failed_event_id1", "fake cause" ) ) self.get_success( self.store.record_event_failed_pull_attempt( - room_id, "insertion_eventA", "fake cause" + room_id, "$failed_event_id2", "fake cause" ) ) - # Now advance time by 2 hours and we should only be able to see - # "insertion_eventB" because we have waited long enough for the single - # attempt (2^1 hours) but we still shouldn't see "insertion_eventA" - # because we haven't waited long enough for this many attempts. - self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) - - # Try at "insertion_eventA" and make sure that "insertion_eventA" is not - # in the list because we've already attempted many times - backfill_points = self.get_success( - self.store.get_insertion_event_backward_extremities_in_room( - room_id, depth_map["insertion_eventA"], limit=100 + event_ids_with_failed_pull_attempts = self.get_success( + self.store.get_event_ids_with_failed_pull_attempts( + event_ids=[ + "$failed_event_id1", + "$fresh_event_id1", + "$failed_event_id2", + "$fresh_event_id2", + ] ) ) - backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertEqual(backfill_event_ids, []) - - # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and - # see if we can now backfill it - self.reactor.advance(datetime.timedelta(hours=20).total_seconds()) - # Try at "insertion_eventA" again after we advanced enough time and we - # should see "insertion_eventA" again - backfill_points = self.get_success( - self.store.get_insertion_event_backward_extremities_in_room( - room_id, depth_map["insertion_eventA"], limit=100 - ) + self.assertEqual( + event_ids_with_failed_pull_attempts, + {"$failed_event_id1", "$failed_event_id2"}, ) - backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points] - self.assertEqual(backfill_event_ids, ["insertion_eventA"]) def test_get_event_ids_to_not_pull_from_backoff(self) -> None: """ diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py deleted file mode 100644 index 5d7c13e6d0..0000000000 --- a/tests/storage/test_keys.py +++ /dev/null @@ -1,137 +0,0 @@ -# Copyright 2017 Vector Creations Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import signedjson.key -import signedjson.types -import unpaddedbase64 - -from synapse.storage.keys import FetchKeyResult - -import tests.unittest - - -def decode_verify_key_base64( - key_id: str, key_base64: str -) -> signedjson.types.VerifyKey: - key_bytes = unpaddedbase64.decode_base64(key_base64) - return signedjson.key.decode_verify_key_bytes(key_id, key_bytes) - - -KEY_1 = decode_verify_key_base64( - "ed25519:key1", "fP5l4JzpZPq/zdbBg5xx6lQGAAOM9/3w94cqiJ5jPrw" -) -KEY_2 = decode_verify_key_base64( - "ed25519:key2", "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw" -) - - -class KeyStoreTestCase(tests.unittest.HomeserverTestCase): - def test_get_server_signature_keys(self) -> None: - store = self.hs.get_datastores().main - - key_id_1 = "ed25519:key1" - key_id_2 = "ed25519:KEY_ID_2" - self.get_success( - store.store_server_signature_keys( - "from_server", - 10, - { - ("server1", key_id_1): FetchKeyResult(KEY_1, 100), - ("server1", key_id_2): FetchKeyResult(KEY_2, 200), - }, - ) - ) - - res = self.get_success( - store.get_server_signature_keys( - [ - ("server1", key_id_1), - ("server1", key_id_2), - ("server1", "ed25519:key3"), - ] - ) - ) - - self.assertEqual(len(res.keys()), 3) - res1 = res[("server1", key_id_1)] - self.assertEqual(res1.verify_key, KEY_1) - self.assertEqual(res1.verify_key.version, "key1") - self.assertEqual(res1.valid_until_ts, 100) - - res2 = res[("server1", key_id_2)] - self.assertEqual(res2.verify_key, KEY_2) - # version comes from the ID it was stored with - self.assertEqual(res2.verify_key.version, "KEY_ID_2") - self.assertEqual(res2.valid_until_ts, 200) - - # non-existent result gives None - self.assertIsNone(res[("server1", "ed25519:key3")]) - - def test_cache(self) -> None: - """Check that updates correctly invalidate the cache.""" - - store = self.hs.get_datastores().main - - key_id_1 = "ed25519:key1" - key_id_2 = "ed25519:key2" - - self.get_success( - store.store_server_signature_keys( - "from_server", - 0, - { - ("srv1", key_id_1): FetchKeyResult(KEY_1, 100), - ("srv1", key_id_2): FetchKeyResult(KEY_2, 200), - }, - ) - ) - - res = self.get_success( - store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)]) - ) - self.assertEqual(len(res.keys()), 2) - - res1 = res[("srv1", key_id_1)] - self.assertEqual(res1.verify_key, KEY_1) - self.assertEqual(res1.valid_until_ts, 100) - - res2 = res[("srv1", key_id_2)] - self.assertEqual(res2.verify_key, KEY_2) - self.assertEqual(res2.valid_until_ts, 200) - - # we should be able to look up the same thing again without a db hit - res = self.get_success(store.get_server_signature_keys([("srv1", key_id_1)])) - self.assertEqual(len(res.keys()), 1) - self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) - - new_key_2 = signedjson.key.get_verify_key( - signedjson.key.generate_signing_key("key2") - ) - d = store.store_server_signature_keys( - "from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)} - ) - self.get_success(d) - - res = self.get_success( - store.get_server_signature_keys([("srv1", key_id_1), ("srv1", key_id_2)]) - ) - self.assertEqual(len(res.keys()), 2) - - res1 = res[("srv1", key_id_1)] - self.assertEqual(res1.verify_key, KEY_1) - self.assertEqual(res1.valid_until_ts, 100) - - res2 = res[("srv1", key_id_2)] - self.assertEqual(res2.verify_key, new_key_2) - self.assertEqual(res2.valid_until_ts, 300) diff --git a/tests/storage/test_main.py b/tests/storage/test_main.py index 27f450e22d..b8823d6993 100644 --- a/tests/storage/test_main.py +++ b/tests/storage/test_main.py @@ -20,7 +20,7 @@ from tests import unittest class DataStoreTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: - super(DataStoreTestCase, self).setUp() + super().setUp() self.store = self.hs.get_datastores().main diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 2827738379..49366440ce 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any, Dict, List -from unittest.mock import Mock +from unittest.mock import AsyncMock from twisted.test.proto_helpers import MemoryReactor @@ -21,7 +21,6 @@ from synapse.server import HomeServer from synapse.util import Clock from tests import unittest -from tests.test_utils import make_awaitable from tests.unittest import default_config, override_config FORTY_DAYS = 40 * 24 * 60 * 60 @@ -253,7 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): ) self.get_success(d) - self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] + self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] d = self.store.populate_monthly_active_users(user_id) self.get_success(d) @@ -261,24 +260,22 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.store.upsert_monthly_active_user.assert_not_called() def test_populate_monthly_users_should_update(self) -> None: - self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] + self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] - self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] + self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[method-assign] - self.store.user_last_seen_monthly_active = Mock( - return_value=make_awaitable(None) - ) + self.store.user_last_seen_monthly_active = AsyncMock(return_value=None) d = self.store.populate_monthly_active_users("user_id") self.get_success(d) self.store.upsert_monthly_active_user.assert_called_once() def test_populate_monthly_users_should_not_update(self) -> None: - self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] + self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] - self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] - self.store.user_last_seen_monthly_active = Mock( - return_value=make_awaitable(self.hs.get_clock().time_msec()) + self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[method-assign] + self.store.user_last_seen_monthly_active = AsyncMock( + return_value=self.hs.get_clock().time_msec() ) d = self.store.populate_monthly_active_users("user_id") @@ -359,7 +356,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": False, "mau_stats_only": False}) def test_no_users_when_not_tracking(self) -> None: - self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] + self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[method-assign] self.get_success(self.store.populate_monthly_active_users("@user:sever")) diff --git a/tests/storage/test_profile.py b/tests/storage/test_profile.py index f9cf0fcb82..95f99f4130 100644 --- a/tests/storage/test_profile.py +++ b/tests/storage/test_profile.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer @@ -35,18 +36,14 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): self.assertEqual( "Frank", - ( - self.get_success( - self.store.get_profile_displayname(self.u_frank.localpart) - ) - ), + (self.get_success(self.store.get_profile_displayname(self.u_frank))), ) # test set to None self.get_success(self.store.set_profile_displayname(self.u_frank, None)) self.assertIsNone( - self.get_success(self.store.get_profile_displayname(self.u_frank.localpart)) + self.get_success(self.store.get_profile_displayname(self.u_frank)) ) def test_avatar_url(self) -> None: @@ -58,18 +55,14 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): self.assertEqual( "http://my.site/here", - ( - self.get_success( - self.store.get_profile_avatar_url(self.u_frank.localpart) - ) - ), + (self.get_success(self.store.get_profile_avatar_url(self.u_frank))), ) # test set to None self.get_success(self.store.set_profile_avatar_url(self.u_frank, None)) self.assertIsNone( - self.get_success(self.store.get_profile_avatar_url(self.u_frank.localpart)) + self.get_success(self.store.get_profile_avatar_url(self.u_frank)) ) def test_profiles_bg_migration(self) -> None: @@ -89,7 +82,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): self.get_success(self.store.db_pool.runInteraction("", f)) - for i in range(0, 70): + for i in range(70): self.get_success( self.store.db_pool.simple_insert( "profiles", @@ -122,7 +115,7 @@ class ProfileStoreTestCase(unittest.HomeserverTestCase): ) expected_values = [] - for i in range(0, 70): + for i in range(70): expected_values.append((f"@hello{i:02}:{self.hs.hostname}",)) res = self.get_success( diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 857e2caf2e..0282673167 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -27,7 +27,7 @@ class PurgeTests(HomeserverTestCase): servlets = [room.register_servlets] def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - hs = self.setup_test_homeserver("server", federation_http_client=None) + hs = self.setup_test_homeserver("server") return hs def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 05ea802008..0cca34d355 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -16,7 +16,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import UserTypes from synapse.api.errors import ThreepidValidationError from synapse.server import HomeServer -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, UserID, UserInfo from synapse.util import Clock from tests.unittest import HomeserverTestCase, override_config @@ -35,22 +35,22 @@ class RegistrationStoreTestCase(HomeserverTestCase): self.get_success(self.store.register_user(self.user_id, self.pwhash)) self.assertEqual( - { + UserInfo( # TODO(paul): Surely this field should be 'user_id', not 'name' - "name": self.user_id, - "password_hash": self.pwhash, - "admin": 0, - "is_guest": 0, - "consent_version": None, - "consent_ts": None, - "consent_server_notice_sent": None, - "appservice_id": None, - "creation_ts": 0, - "user_type": None, - "deactivated": 0, - "shadow_banned": 0, - "approved": 1, - }, + user_id=UserID.from_string(self.user_id), + is_admin=False, + is_guest=False, + consent_server_notice_sent=None, + consent_ts=None, + consent_version=None, + appservice_id=None, + creation_ts=0, + user_type=None, + is_deactivated=False, + locked=False, + is_shadow_banned=False, + approved=True, + ), (self.get_success(self.store.get_user_by_id(self.user_id))), ) @@ -63,9 +63,11 @@ class RegistrationStoreTestCase(HomeserverTestCase): user = self.get_success(self.store.get_user_by_id(self.user_id)) assert user - self.assertEqual(user["consent_version"], "1") - self.assertGreater(user["consent_ts"], before_consent) - self.assertLess(user["consent_ts"], self.clock.time_msec()) + self.assertEqual(user.consent_version, "1") + self.assertIsNotNone(user.consent_ts) + assert user.consent_ts is not None + self.assertGreater(user.consent_ts, before_consent) + self.assertLess(user.consent_ts, self.clock.time_msec()) def test_add_tokens(self) -> None: self.get_success(self.store.register_user(self.user_id, self.pwhash)) @@ -213,7 +215,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase): user = self.get_success(self.store.get_user_by_id(self.user_id)) assert user is not None - self.assertTrue(user["approved"]) + self.assertTrue(user.approved) approved = self.get_success(self.store.is_user_approved(self.user_id)) self.assertTrue(approved) @@ -226,7 +228,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase): user = self.get_success(self.store.get_user_by_id(self.user_id)) assert user is not None - self.assertFalse(user["approved"]) + self.assertFalse(user.approved) approved = self.get_success(self.store.is_user_approved(self.user_id)) self.assertFalse(approved) @@ -246,7 +248,7 @@ class ApprovalRequiredRegistrationTestCase(HomeserverTestCase): user = self.get_success(self.store.get_user_by_id(self.user_id)) self.assertIsNotNone(user) assert user is not None - self.assertEqual(user["approved"], 1) + self.assertEqual(user.approved, 1) approved = self.get_success(self.store.is_user_approved(self.user_id)) self.assertTrue(approved) diff --git a/tests/storage/test_rollback_worker.py b/tests/storage/test_rollback_worker.py index 966aafea6f..809c9f175d 100644 --- a/tests/storage/test_rollback_worker.py +++ b/tests/storage/test_rollback_worker.py @@ -45,9 +45,7 @@ def fake_listdir(filepath: str) -> List[str]: class WorkerSchemaTests(HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - hs = self.setup_test_homeserver( - federation_http_client=None, homeserver_to_use=GenericWorkerServer - ) + hs = self.setup_test_homeserver(homeserver_to_use=GenericWorkerServer) return hs def default_config(self) -> JsonDict: @@ -55,6 +53,7 @@ class WorkerSchemaTests(HomeserverTestCase): # Mark this as a worker app. conf["worker_app"] = "yes" + conf["instance_map"] = {"main": {"host": "127.0.0.1", "port": 0}} return conf diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index 71ec74eadc..1e27f2c275 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -44,13 +44,13 @@ class RoomStoreTestCase(HomeserverTestCase): def test_get_room(self) -> None: res = self.get_success(self.store.get_room(self.room.to_string())) assert res is not None - self.assertDictContainsSubset( + self.assertLessEqual( { "room_id": self.room.to_string(), "creator": self.u_creator.to_string(), "is_public": True, - }, - res, + }.items(), + res.items(), ) def test_get_room_unknown_room(self) -> None: @@ -59,13 +59,13 @@ class RoomStoreTestCase(HomeserverTestCase): def test_get_room_with_stats(self) -> None: res = self.get_success(self.store.get_room_with_stats(self.room.to_string())) assert res is not None - self.assertDictContainsSubset( + self.assertLessEqual( { "room_id": self.room.to_string(), "creator": self.u_creator.to_string(), "public": True, - }, - res, + }.items(), + res.items(), ) def test_get_room_with_stats_unknown_room(self) -> None: diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py index f183c38477..52ffa91c81 100644 --- a/tests/storage/test_room_search.py +++ b/tests/storage/test_room_search.py @@ -318,14 +318,14 @@ class MessageSearchTest(HomeserverTestCase): result = self.get_success( store.search_msgs([self.room_id], query, ["content.body"]) ) - self.assertEquals( + self.assertEqual( result["count"], 1 if expect_to_contain else 0, f"expected '{query}' to match '{self.PHRASE}'" if expect_to_contain else f"'{query}' unexpectedly matched '{self.PHRASE}'", ) - self.assertEquals( + self.assertEqual( len(result["results"]), 1 if expect_to_contain else 0, "results array length should match count", @@ -336,14 +336,14 @@ class MessageSearchTest(HomeserverTestCase): result = self.get_success( store.search_rooms([self.room_id], query, ["content.body"], 10) ) - self.assertEquals( + self.assertEqual( result["count"], 1 if expect_to_contain else 0, f"expected '{query}' to match '{self.PHRASE}'" if expect_to_contain else f"'{query}' unexpectedly matched '{self.PHRASE}'", ) - self.assertEquals( + self.assertEqual( len(result["results"]), 1 if expect_to_contain else 0, "results array length should match count", diff --git a/tests/storage/test_transactions.py b/tests/storage/test_transactions.py index db9ee9955e..ef06b50dbb 100644 --- a/tests/storage/test_transactions.py +++ b/tests/storage/test_transactions.py @@ -17,7 +17,6 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.server import HomeServer from synapse.storage.databases.main.transactions import DestinationRetryTimings from synapse.util import Clock -from synapse.util.retryutils import MAX_RETRY_INTERVAL from tests.unittest import HomeserverTestCase @@ -33,15 +32,14 @@ class TransactionStoreTestCase(HomeserverTestCase): destination retries, as well as testing tht we can set and get correctly. """ - d = self.store.get_destination_retry_timings("example.com") - r = self.get_success(d) + r = self.get_success(self.store.get_destination_retry_timings("example.com")) self.assertIsNone(r) - d = self.store.set_destination_retry_timings("example.com", 1000, 50, 100) - self.get_success(d) + self.get_success( + self.store.set_destination_retry_timings("example.com", 1000, 50, 100) + ) - d = self.store.get_destination_retry_timings("example.com") - r = self.get_success(d) + r = self.get_success(self.store.get_destination_retry_timings("example.com")) self.assertEqual( DestinationRetryTimings( @@ -58,8 +56,14 @@ class TransactionStoreTestCase(HomeserverTestCase): self.get_success(d) def test_large_destination_retry(self) -> None: + max_retry_interval_ms = ( + self.hs.config.federation.destination_max_retry_interval_ms + ) d = self.store.set_destination_retry_timings( - "example.com", MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL, MAX_RETRY_INTERVAL + "example.com", + max_retry_interval_ms, + max_retry_interval_ms, + max_retry_interval_ms, ) self.get_success(d) diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py index 15ea4770bd..22f074982f 100644 --- a/tests/storage/test_txn_limit.py +++ b/tests/storage/test_txn_limit.py @@ -38,5 +38,5 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase): db_pool = self.hs.get_datastores().databases[0] # force txn limit to roll over at least once - for _ in range(0, 1001): + for _ in range(1001): self.get_success_or_raise(db_pool.runInteraction("test_select", do_select)) diff --git a/tests/storage/test_user_filters.py b/tests/storage/test_user_filters.py index bab802f56e..d4637d9d1e 100644 --- a/tests/storage/test_user_filters.py +++ b/tests/storage/test_user_filters.py @@ -45,7 +45,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase): self.get_success(self.store.db_pool.runInteraction("", f)) - for i in range(0, 70): + for i in range(70): self.get_success( self.store.db_pool.simple_insert( "user_filters", @@ -82,7 +82,7 @@ class UserFiltersStoreTestCase(unittest.HomeserverTestCase): ) expected_values = [] - for i in range(0, 70): + for i in range(70): expected_values.append((f"@hello{i:02}:{self.hs.hostname}",)) res = self.get_success( diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py index 0e3fc2a77f..29be8cdbd0 100644 --- a/tests/storage/util/test_partial_state_events_tracker.py +++ b/tests/storage/util/test_partial_state_events_tracker.py @@ -22,7 +22,6 @@ from synapse.storage.util.partial_state_events_tracker import ( PartialStateEventsTracker, ) -from tests.test_utils import make_awaitable from tests.unittest import TestCase @@ -124,16 +123,17 @@ class PartialStateEventsTrackerTestCase(TestCase): class PartialCurrentStateTrackerTestCase(TestCase): def setUp(self) -> None: self.mock_store = mock.Mock(spec_set=["is_partial_state_room"]) + self.mock_store.is_partial_state_room = mock.AsyncMock() self.tracker = PartialCurrentStateTracker(self.mock_store) def test_does_not_block_for_full_state_rooms(self) -> None: - self.mock_store.is_partial_state_room.return_value = make_awaitable(False) + self.mock_store.is_partial_state_room.return_value = False self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) def test_blocks_for_partial_room_state(self) -> None: - self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + self.mock_store.is_partial_state_room.return_value = True d = ensureDeferred(self.tracker.await_full_state("room_id")) @@ -156,7 +156,7 @@ class PartialCurrentStateTrackerTestCase(TestCase): self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) def test_cancellation(self) -> None: - self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + self.mock_store.is_partial_state_room.return_value = True d1 = ensureDeferred(self.tracker.await_full_state("room_id")) self.assertNoResult(d1) diff --git a/tests/test_federation.py b/tests/test_federation.py index 6d15ac7597..1b0504709e 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -13,7 +13,7 @@ # limitations under the License. from typing import Collection, List, Optional, Union -from unittest.mock import Mock +from unittest.mock import AsyncMock, Mock from twisted.test.proto_helpers import MemoryReactor @@ -31,7 +31,6 @@ from synapse.util import Clock from synapse.util.retryutils import NotRetryingDestination from tests import unittest -from tests.test_utils import make_awaitable class MessageAcceptTests(unittest.HomeserverTestCase): @@ -52,9 +51,15 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.store = self.hs.get_datastores().main # Figure out what the most recent event is - most_recent = self.get_success( - self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) - )[0] + most_recent = next( + iter( + self.get_success( + self.hs.get_datastores().main.get_latest_event_ids_in_room( + self.room_id + ) + ) + ) + ) join_event = make_event_from_dict( { @@ -81,7 +86,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) -> None: pass - federation_event_handler._check_event_auth = _check_event_auth # type: ignore[assignment] + federation_event_handler._check_event_auth = _check_event_auth # type: ignore[method-assign] self.client = self.hs.get_federation_client() async def _check_sigs_and_hash_for_pulled_events_and_fetch( @@ -101,8 +106,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Make sure we actually joined the room self.assertEqual( - self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0], - "$join:test.serv", + self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)), + {"$join:test.serv"}, ) def test_cant_hide_direct_ancestors(self) -> None: @@ -128,9 +133,11 @@ class MessageAcceptTests(unittest.HomeserverTestCase): self.http_client.post_json = post_json # Figure out what the most recent event is - most_recent = self.get_success( - self.store.get_latest_event_ids_in_room(self.room_id) - )[0] + most_recent = next( + iter( + self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) + ) + ) # Now lie about an event lying_event = make_event_from_dict( @@ -166,7 +173,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Make sure the invalid event isn't there extrem = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id)) - self.assertEqual(extrem[0], "$join:test.serv") + self.assertEqual(extrem, {"$join:test.serv"}) def test_retry_device_list_resync(self) -> None: """Tests that device lists are marked as stale if they couldn't be synced, and @@ -191,12 +198,12 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register the mock on the federation client. federation_client = self.hs.get_federation_client() - federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[assignment] + federation_client.query_user_devices = Mock(side_effect=query_user_devices) # type: ignore[method-assign] # Register a mock on the store so that the incoming update doesn't fail because # we don't share a room with the user. store = self.hs.get_datastores().main - store.get_rooms_for_user = Mock(return_value=make_awaitable(["!someroom:test"])) + store.get_rooms_for_user = AsyncMock(return_value=["!someroom:test"]) # Manually inject a fake device list update. We need this update to include at # least one prev_id so that the user's device list will need to be retried. @@ -241,27 +248,24 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register mock device list retrieval on the federation client. federation_client = self.hs.get_federation_client() - federation_client.query_user_devices = Mock( # type: ignore[assignment] - return_value=make_awaitable( - { + federation_client.query_user_devices = AsyncMock( # type: ignore[method-assign] + return_value={ + "user_id": remote_user_id, + "stream_id": 1, + "devices": [], + "master_key": { "user_id": remote_user_id, - "stream_id": 1, - "devices": [], - "master_key": { - "user_id": remote_user_id, - "usage": ["master"], - "keys": {"ed25519:" + remote_master_key: remote_master_key}, - }, - "self_signing_key": { - "user_id": remote_user_id, - "usage": ["self_signing"], - "keys": { - "ed25519:" - + remote_self_signing_key: remote_self_signing_key - }, + "usage": ["master"], + "keys": {"ed25519:" + remote_master_key: remote_master_key}, + }, + "self_signing_key": { + "user_id": remote_user_id, + "usage": ["self_signing"], + "keys": { + "ed25519:" + remote_self_signing_key: remote_self_signing_key }, - } - ) + }, + } ) # Resync the device list. diff --git a/tests/test_server.py b/tests/test_server.py index e266c06a2c..36162cd1f5 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -38,7 +38,7 @@ from tests.http.server._base import test_disconnect from tests.server import ( FakeChannel, FakeSite, - ThreadedMemoryReactorClock, + get_clock, make_request, setup_test_homeserver, ) @@ -46,12 +46,11 @@ from tests.server import ( class JsonResourceTests(unittest.TestCase): def setUp(self) -> None: - self.reactor = ThreadedMemoryReactorClock() - self.hs_clock = Clock(self.reactor) + reactor, clock = get_clock() + self.reactor = reactor self.homeserver = setup_test_homeserver( self.addCleanup, - federation_http_client=None, - clock=self.hs_clock, + clock=clock, reactor=self.reactor, ) @@ -209,7 +208,13 @@ class JsonResourceTests(unittest.TestCase): class OptionsResourceTests(unittest.TestCase): def setUp(self) -> None: - self.reactor = ThreadedMemoryReactorClock() + reactor, clock = get_clock() + self.reactor = reactor + self.homeserver = setup_test_homeserver( + self.addCleanup, + clock=clock, + reactor=self.reactor, + ) class DummyResource(Resource): isLeaf = True @@ -242,6 +247,7 @@ class OptionsResourceTests(unittest.TestCase): "1.0", max_request_body_size=4096, reactor=self.reactor, + hs=self.homeserver, ) # render the request and return the channel @@ -268,7 +274,7 @@ class OptionsResourceTests(unittest.TestCase): ) self.assertEqual( channel.headers.getRawHeaders(b"Access-Control-Expose-Headers"), - [b"Synapse-Trace-Id"], + [b"Synapse-Trace-Id, Server"], ) def _check_cors_msc3886_headers(self, channel: FakeChannel) -> None: @@ -344,7 +350,8 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): await self.callback(request) def setUp(self) -> None: - self.reactor = ThreadedMemoryReactorClock() + reactor, _ = get_clock() + self.reactor = reactor def test_good_response(self) -> None: async def callback(request: SynapseRequest) -> None: @@ -462,9 +469,9 @@ class DirectServeJsonResourceCancellationTests(unittest.TestCase): """Tests for `DirectServeJsonResource` cancellation.""" def setUp(self) -> None: - self.reactor = ThreadedMemoryReactorClock() - self.clock = Clock(self.reactor) - self.resource = CancellableDirectServeJsonResource(self.clock) + reactor, clock = get_clock() + self.reactor = reactor + self.resource = CancellableDirectServeJsonResource(clock) self.site = FakeSite(self.resource, self.reactor) def test_cancellable_disconnect(self) -> None: @@ -496,9 +503,9 @@ class DirectServeHtmlResourceCancellationTests(unittest.TestCase): """Tests for `DirectServeHtmlResource` cancellation.""" def setUp(self) -> None: - self.reactor = ThreadedMemoryReactorClock() - self.clock = Clock(self.reactor) - self.resource = CancellableDirectServeHtmlResource(self.clock) + reactor, clock = get_clock() + self.reactor = reactor + self.resource = CancellableDirectServeHtmlResource(clock) self.site = FakeSite(self.resource, self.reactor) def test_cancellable_disconnect(self) -> None: diff --git a/tests/test_state.py b/tests/test_state.py index ddf59916b1..9c8679cc1d 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -28,7 +28,7 @@ from unittest.mock import Mock from twisted.internet import defer -from synapse.api.auth import Auth +from synapse.api.auth.internal import InternalAuth from synapse.api.constants import EventTypes, Membership from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, make_event_from_dict @@ -240,7 +240,7 @@ class StateTestCase(unittest.TestCase): hs.get_macaroon_generator.return_value = MacaroonGenerator( clock, "tesths", b"verysecret" ) - hs.get_auth.return_value = Auth(hs) + hs.get_auth.return_value = InternalAuth(hs) hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs) hs.get_storage_controllers.return_value = storage_controllers @@ -555,10 +555,15 @@ class StateTestCase(unittest.TestCase): (e.event_id for e in old_state + [event]), current_state_ids.values() ) - self.assertIsNotNone(context.state_group_before_event) + assert context.state_group_before_event is not None + assert context.state_group is not None + self.assertEqual( + context.state_group_deltas.get( + (context.state_group_before_event, context.state_group) + ), + {(event.type, event.state_key): event.event_id}, + ) self.assertNotEqual(context.state_group_before_event, context.state_group) - self.assertEqual(context.state_group_before_event, context.prev_group) - self.assertEqual({("state", ""): event.event_id}, context.delta_ids) @defer.inlineCallbacks def test_trivial_annotate_message( @@ -709,7 +714,7 @@ class StateTestCase(unittest.TestCase): store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.dummy_store.get_events = store.get_events # type: ignore[assignment] + self.dummy_store.get_events = store.get_events # type: ignore[method-assign] context: EventContext context = yield self._get_context( @@ -768,7 +773,7 @@ class StateTestCase(unittest.TestCase): store = _DummyStore() store.register_events(old_state_1) store.register_events(old_state_2) - self.dummy_store.get_events = store.get_events # type: ignore[assignment] + self.dummy_store.get_events = store.get_events # type: ignore[method-assign] context: EventContext context = yield self._get_context( diff --git a/tests/test_terms_auth.py b/tests/test_terms_auth.py index 52424aa087..64a49488c6 100644 --- a/tests/test_terms_auth.py +++ b/tests/test_terms_auth.py @@ -85,7 +85,9 @@ class TermsTestCase(unittest.HomeserverTestCase): } } self.assertIsInstance(channel.json_body["params"], dict) - self.assertDictContainsSubset(channel.json_body["params"], expected_params) + self.assertLessEqual( + channel.json_body["params"].items(), expected_params.items() + ) # We have to complete the dummy auth stage before completing the terms stage request_data = { diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index e5dae670a7..fa731426cd 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -18,10 +18,8 @@ Utilities for running the unit tests import json import sys import warnings -from asyncio import Future from binascii import unhexlify -from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar -from unittest.mock import Mock +from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar import attr import zope.interface @@ -33,7 +31,7 @@ from twisted.web.http import RESPONSES from twisted.web.http_headers import Headers from twisted.web.iweb import IResponse -from synapse.types import JsonDict +from synapse.types import JsonSerializable if TYPE_CHECKING: from sys import UnraisableHookArgs @@ -57,27 +55,12 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV: raise Exception("awaitable has not yet completed") -def make_awaitable(result: TV) -> Awaitable[TV]: - """ - Makes an awaitable, suitable for mocking an `async` function. - This uses Futures as they can be awaited multiple times so can be returned - to multiple callers. - """ - future: Future[TV] = Future() - future.set_result(result) - return future - - def setup_awaitable_errors() -> Callable[[], None]: """ Convert warnings from a non-awaited coroutines into errors. """ warnings.simplefilter("error", RuntimeWarning) - # unraisablehook was added in Python 3.8. - if not hasattr(sys, "unraisablehook"): - return lambda: None - # State shared between unraisablehook and check_for_unraisable_exceptions. unraisable_exceptions = [] orig_unraisablehook = sys.unraisablehook @@ -100,18 +83,6 @@ def setup_awaitable_errors() -> Callable[[], None]: return cleanup -def simple_async_mock( - return_value: Optional[TV] = None, raises: Optional[Exception] = None -) -> Mock: - # AsyncMock is not available in python3.5, this mimics part of its behaviour - async def cb(*args: Any, **kwargs: Any) -> Optional[TV]: - if raises: - raise raises - return return_value - - return Mock(side_effect=cb) - - # Type ignore: it does not fully implement IResponse, but is good enough for tests @zope.interface.implementer(IResponse) @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -145,7 +116,7 @@ class FakeResponse: # type: ignore[misc] protocol.connectionLost(Failure(ResponseDone())) @classmethod - def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse": + def json(cls, *, code: int = 200, payload: JsonSerializable) -> "FakeResponse": headers = Headers({"Content-Type": ["application/json"]}) body = json.dumps(payload).encode("utf-8") return cls(code=code, body=body, headers=headers) diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index b522163a34..199bb06a81 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -40,10 +40,9 @@ def setup_logging() -> None: """ root_logger = logging.getLogger() - log_format = ( - "%(asctime)s - %(name)s - %(lineno)d - " - "%(levelname)s - %(request)s - %(message)s" - ) + # We exclude `%(asctime)s` from this format because the Twisted logger adds its own + # timestamp + log_format = "%(name)s - %(lineno)d - " "%(levelname)s - %(request)s - %(message)s" handler = ToTwistedHandler() formatter = logging.Formatter(log_format) @@ -54,4 +53,16 @@ def setup_logging() -> None: log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR") root_logger.setLevel(log_level) + # In order to not add noise by default (since we only log ERROR messages for trial + # tests as configured above), we only enable this for developers for looking for + # more INFO or DEBUG. + if root_logger.isEnabledFor(logging.INFO): + # Log when events are (maybe unexpectedly) filtered out of responses in tests. It's + # just nice to be able to look at the CI log and figure out why an event isn't being + # returned. + logging.getLogger("synapse.visibility.filtered_event_debug").setLevel( + logging.DEBUG + ) + + # Blow away the pyo3-log cache so that it reloads the configuration. reset_logging_config() diff --git a/tests/test_visibility.py b/tests/test_visibility.py index 9ed330f554..434902c3f0 100644 --- a/tests/test_visibility.py +++ b/tests/test_visibility.py @@ -31,7 +31,7 @@ TEST_ROOM_ID = "!TEST:ROOM" class FilterEventsForServerTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: - super(FilterEventsForServerTestCase, self).setUp() + super().setUp() self.event_creation_handler = self.hs.get_event_creation_handler() self.event_builder_factory = self.hs.get_event_builder_factory() self._storage_controllers = self.hs.get_storage_controllers() @@ -51,12 +51,12 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): # before we do that, we persist some other events to act as state. self._inject_visibility("@admin:hs", "joined") - for i in range(0, 10): + for i in range(10): self._inject_room_member("@resident%i:hs" % i) events_to_filter = [] - for i in range(0, 10): + for i in range(10): user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server") evt = self._inject_room_member(user, extra_content={"a": "b"}) events_to_filter.append(evt) @@ -74,7 +74,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): ) # the result should be 5 redacted events, and 5 unredacted events. - for i in range(0, 5): + for i in range(5): self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id) self.assertNotIn("a", filtered[i].content) @@ -177,7 +177,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase): ) ) - for i in range(0, len(events_to_filter)): + for i in range(len(events_to_filter)): self.assertEqual( events_to_filter[i].event_id, filtered[i].event_id, diff --git a/tests/unittest.py b/tests/unittest.py index 623c5a75a2..99ad02eb06 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import functools import gc import hashlib import hmac @@ -59,7 +60,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.crypto.event_signing import add_hashes_and_signatures from synapse.federation.transport.server import TransportLayerServer -from synapse.http.server import JsonResource +from synapse.http.server import JsonResource, OptionsResource from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import ( SENTINEL_CONTEXT, @@ -69,6 +70,7 @@ from synapse.logging.context import ( ) from synapse.rest import RegisterServletsFunc from synapse.server import HomeServer +from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree @@ -150,7 +152,11 @@ def deepcopy_config(config: _TConfig) -> _TConfig: return new_config -_make_homeserver_config_obj_cache: Dict[str, Union[RootConfig, Config]] = {} +@functools.lru_cache(maxsize=8) +def _parse_config_dict(config: str) -> RootConfig: + config_obj = HomeServerConfig() + config_obj.parse_config_dict(json.loads(config), "", "") + return config_obj def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig: @@ -164,21 +170,7 @@ def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig: but it keeps a cache of `HomeServerConfig` instances and deepcopies them as needed, to avoid validating the whole configuration every time. """ - cache_key = json.dumps(config) - - if cache_key in _make_homeserver_config_obj_cache: - # Cache hit: reuse the existing instance - config_obj = _make_homeserver_config_obj_cache[cache_key] - else: - # Cache miss; create the actual instance - config_obj = HomeServerConfig() - config_obj.parse_config_dict(config, "", "") - - # Add to the cache - _make_homeserver_config_obj_cache[cache_key] = config_obj - - assert isinstance(config_obj, RootConfig) - + config_obj = _parse_config_dict(json.dumps(config, sort_keys=True)) return deepcopy_config(config_obj) @@ -322,7 +314,7 @@ class HomeserverTestCase(TestCase): servlets: List of servlet registration function. user_id (str): The user ID to assume if auth is hijacked. hijack_auth: Whether to hijack auth to return the user specified - in user_id. + in user_id. """ hijack_auth: ClassVar[bool] = True @@ -367,6 +359,7 @@ class HomeserverTestCase(TestCase): server_version_string="1", max_request_body_size=4096, reactor=self.reactor, + hs=self.hs, ) from tests.rest.client.utils import RestHelper @@ -403,9 +396,9 @@ class HomeserverTestCase(TestCase): ) # Type ignore: mypy doesn't like us assigning to methods. - self.hs.get_auth().get_user_by_req = get_requester # type: ignore[assignment] - self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[assignment] - self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[assignment] + self.hs.get_auth().get_user_by_req = get_requester # type: ignore[method-assign] + self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[method-assign] + self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[method-assign] if self.needs_threadpool: self.reactor.threadpool = ThreadPool() # type: ignore[assignment] @@ -466,7 +459,7 @@ class HomeserverTestCase(TestCase): The default calls `self.create_resource_dict` and builds the resultant dict into a tree. """ - root_resource = Resource() + root_resource = OptionsResource() create_resource_tree(self.create_resource_dict(), root_resource) return root_resource @@ -866,23 +859,22 @@ class FederatingHomeserverTestCase(HomeserverTestCase): verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version) self.get_success( - hs.get_datastores().main.store_server_keys_json( + hs.get_datastores().main.store_server_keys_response( self.OTHER_SERVER_NAME, - verify_key_id, from_server=self.OTHER_SERVER_NAME, - ts_now_ms=clock.time_msec(), - ts_expires_ms=clock.time_msec() + 10000, - key_json_bytes=canonicaljson.encode_canonical_json( - { - "verify_keys": { - verify_key_id: { - "key": signedjson.key.encode_verify_key_base64( - verify_key - ) - } + ts_added_ms=clock.time_msec(), + verify_keys={ + verify_key_id: FetchKeyResult( + verify_key=verify_key, valid_until_ts=clock.time_msec() + 10000 + ), + }, + response_json={ + "verify_keys": { + verify_key_id: { + "key": signedjson.key.encode_verify_key_base64(verify_key) } } - ), + }, ) ) diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 13f1edd533..7e8725e610 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -13,7 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Iterable, Set, Tuple, cast +from typing import ( + Any, + Generator, + Iterable, + List, + Mapping, + NoReturn, + Optional, + Set, + Tuple, + cast, +) from unittest import mock from twisted.internet import defer, reactor @@ -29,7 +40,7 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached, cachedList +from synapse.util.caches.descriptors import _CacheContext, cached, cachedList from tests import unittest from tests.test_utils import get_awaitable_result @@ -37,21 +48,21 @@ from tests.test_utils import get_awaitable_result logger = logging.getLogger(__name__) -def run_on_reactor(): - d: "Deferred[int]" = defer.Deferred() +def run_on_reactor() -> "Deferred[int]": + d: "Deferred[int]" = Deferred() cast(IReactorTime, reactor).callLater(0, d.callback, 0) return make_deferred_yieldable(d) class DescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks - def test_cache(self): + def test_cache(self) -> Generator["Deferred[Any]", object, None]: class Cls: - def __init__(self): + def __init__(self) -> None: self.mock = mock.Mock() @descriptors.cached() - def fn(self, arg1, arg2): + def fn(self, arg1: int, arg2: int) -> str: return self.mock(arg1, arg2) obj = Cls() @@ -77,15 +88,15 @@ class DescriptorTestCase(unittest.TestCase): obj.mock.assert_not_called() @defer.inlineCallbacks - def test_cache_num_args(self): + def test_cache_num_args(self) -> Generator["Deferred[Any]", object, None]: """Only the first num_args arguments should matter to the cache""" class Cls: - def __init__(self): + def __init__(self) -> None: self.mock = mock.Mock() @descriptors.cached(num_args=1) - def fn(self, arg1, arg2): + def fn(self, arg1: int, arg2: int) -> str: return self.mock(arg1, arg2) obj = Cls() @@ -111,7 +122,7 @@ class DescriptorTestCase(unittest.TestCase): obj.mock.assert_not_called() @defer.inlineCallbacks - def test_cache_uncached_args(self): + def test_cache_uncached_args(self) -> Generator["Deferred[Any]", object, None]: """ Only the arguments not named in uncached_args should matter to the cache @@ -123,10 +134,10 @@ class DescriptorTestCase(unittest.TestCase): # Note that it is important that this is not the last argument to # test behaviour of skipping arguments properly. @descriptors.cached(uncached_args=("arg2",)) - def fn(self, arg1, arg2, arg3): + def fn(self, arg1: int, arg2: int, arg3: int) -> str: return self.mock(arg1, arg2, arg3) - def __init__(self): + def __init__(self) -> None: self.mock = mock.Mock() obj = Cls() @@ -152,15 +163,15 @@ class DescriptorTestCase(unittest.TestCase): obj.mock.assert_not_called() @defer.inlineCallbacks - def test_cache_kwargs(self): + def test_cache_kwargs(self) -> Generator["Deferred[Any]", object, None]: """Test that keyword arguments are treated properly""" class Cls: - def __init__(self): + def __init__(self) -> None: self.mock = mock.Mock() @descriptors.cached() - def fn(self, arg1, kwarg1=2): + def fn(self, arg1: int, kwarg1: int = 2) -> str: return self.mock(arg1, kwarg1=kwarg1) obj = Cls() @@ -188,12 +199,12 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(r, "fish") obj.mock.assert_not_called() - def test_cache_with_sync_exception(self): + def test_cache_with_sync_exception(self) -> None: """If the wrapped function throws synchronously, things should continue to work""" class Cls: @cached() - def fn(self, arg1): + def fn(self, arg1: int) -> NoReturn: raise SynapseError(100, "mai spoon iz too big!!1") obj = Cls() @@ -209,23 +220,24 @@ class DescriptorTestCase(unittest.TestCase): d = obj.fn(1) self.failureResultOf(d, SynapseError) - def test_cache_with_async_exception(self): + def test_cache_with_async_exception(self) -> None: """The wrapped function returns a failure""" class Cls: - result = None + result: Optional[Deferred] = None call_count = 0 @cached() - def fn(self, arg1): + def fn(self, arg1: int) -> Deferred: self.call_count += 1 + assert self.result is not None return self.result obj = Cls() callbacks: Set[str] = set() # set off an asynchronous request - origin_d: Deferred = defer.Deferred() + origin_d: Deferred = Deferred() obj.result = origin_d d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1")) @@ -260,17 +272,17 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(self.successResultOf(d3), 100) self.assertEqual(obj.call_count, 2) - def test_cache_logcontexts(self): + def test_cache_logcontexts(self) -> Deferred: """Check that logcontexts are set and restored correctly when using the cache.""" - complete_lookup: Deferred = defer.Deferred() + complete_lookup: Deferred = Deferred() class Cls: @descriptors.cached() - def fn(self, arg1): + def fn(self, arg1: int) -> "Deferred[int]": @defer.inlineCallbacks - def inner_fn(): + def inner_fn() -> Generator["Deferred[object]", object, int]: with PreserveLoggingContext(): yield complete_lookup return 1 @@ -278,13 +290,13 @@ class DescriptorTestCase(unittest.TestCase): return inner_fn() @defer.inlineCallbacks - def do_lookup(): + def do_lookup() -> Generator["Deferred[Any]", object, int]: with LoggingContext("c1") as c1: r = yield obj.fn(1) self.assertEqual(current_context(), c1) - return r + return cast(int, r) - def check_result(r): + def check_result(r: int) -> None: self.assertEqual(r, 1) obj = Cls() @@ -304,15 +316,15 @@ class DescriptorTestCase(unittest.TestCase): return defer.gatherResults([d1, d2]) - def test_cache_logcontexts_with_exception(self): + def test_cache_logcontexts_with_exception(self) -> "Deferred[None]": """Check that the cache sets and restores logcontexts correctly when the lookup function throws an exception""" class Cls: @descriptors.cached() - def fn(self, arg1): + def fn(self, arg1: int) -> Deferred: @defer.inlineCallbacks - def inner_fn(): + def inner_fn() -> Generator["Deferred[Any]", object, NoReturn]: # we want this to behave like an asynchronous function yield run_on_reactor() raise SynapseError(400, "blah") @@ -320,7 +332,7 @@ class DescriptorTestCase(unittest.TestCase): return inner_fn() @defer.inlineCallbacks - def do_lookup(): + def do_lookup() -> Generator["Deferred[object]", object, None]: with LoggingContext("c1") as c1: try: d = obj.fn(1) @@ -347,13 +359,13 @@ class DescriptorTestCase(unittest.TestCase): return d1 @defer.inlineCallbacks - def test_cache_default_args(self): + def test_cache_default_args(self) -> Generator["Deferred[Any]", object, None]: class Cls: - def __init__(self): + def __init__(self) -> None: self.mock = mock.Mock() @descriptors.cached() - def fn(self, arg1, arg2=2, arg3=3): + def fn(self, arg1: int, arg2: int = 2, arg3: int = 3) -> str: return self.mock(arg1, arg2, arg3) obj = Cls() @@ -384,27 +396,27 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(r, "chips") obj.mock.assert_not_called() - def test_cache_iterable(self): + def test_cache_iterable(self) -> None: class Cls: - def __init__(self): + def __init__(self) -> None: self.mock = mock.Mock() @descriptors.cached(iterable=True) - def fn(self, arg1, arg2): + def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]: return self.mock(arg1, arg2) obj = Cls() - obj.mock.return_value = ["spam", "eggs"] + obj.mock.return_value = ("spam", "eggs") r = obj.fn(1, 2) - self.assertEqual(r.result, ["spam", "eggs"]) + self.assertEqual(r.result, ("spam", "eggs")) obj.mock.assert_called_once_with(1, 2) obj.mock.reset_mock() # a call with different params should call the mock again - obj.mock.return_value = ["chips"] + obj.mock.return_value = ("chips",) r = obj.fn(1, 3) - self.assertEqual(r.result, ["chips"]) + self.assertEqual(r.result, ("chips",)) obj.mock.assert_called_once_with(1, 3) obj.mock.reset_mock() @@ -412,17 +424,17 @@ class DescriptorTestCase(unittest.TestCase): self.assertEqual(len(obj.fn.cache.cache), 3) r = obj.fn(1, 2) - self.assertEqual(r.result, ["spam", "eggs"]) + self.assertEqual(r.result, ("spam", "eggs")) r = obj.fn(1, 3) - self.assertEqual(r.result, ["chips"]) + self.assertEqual(r.result, ("chips",)) obj.mock.assert_not_called() - def test_cache_iterable_with_sync_exception(self): + def test_cache_iterable_with_sync_exception(self) -> None: """If the wrapped function throws synchronously, things should continue to work""" class Cls: @descriptors.cached(iterable=True) - def fn(self, arg1): + def fn(self, arg1: int) -> NoReturn: raise SynapseError(100, "mai spoon iz too big!!1") obj = Cls() @@ -438,20 +450,20 @@ class DescriptorTestCase(unittest.TestCase): d = obj.fn(1) self.failureResultOf(d, SynapseError) - def test_invalidate_cascade(self): + def test_invalidate_cascade(self) -> None: """Invalidations should cascade up through cache contexts""" class Cls: @cached(cache_context=True) - async def func1(self, key, cache_context): + async def func1(self, key: str, cache_context: _CacheContext) -> int: return await self.func2(key, on_invalidate=cache_context.invalidate) @cached(cache_context=True) - async def func2(self, key, cache_context): + async def func2(self, key: str, cache_context: _CacheContext) -> int: return await self.func3(key, on_invalidate=cache_context.invalidate) @cached(cache_context=True) - async def func3(self, key, cache_context): + async def func3(self, key: str, cache_context: _CacheContext) -> int: self.invalidate = cache_context.invalidate return 42 @@ -463,13 +475,13 @@ class DescriptorTestCase(unittest.TestCase): obj.invalidate() top_invalidate.assert_called_once() - def test_cancel(self): + def test_cancel(self) -> None: """Test that cancelling a lookup does not cancel other lookups""" complete_lookup: "Deferred[None]" = Deferred() class Cls: @cached() - async def fn(self, arg1): + async def fn(self, arg1: int) -> str: await complete_lookup return str(arg1) @@ -488,7 +500,7 @@ class DescriptorTestCase(unittest.TestCase): self.failureResultOf(d1, CancelledError) self.assertEqual(d2.result, "123") - def test_cancel_logcontexts(self): + def test_cancel_logcontexts(self) -> None: """Test that cancellation does not break logcontexts. * The `CancelledError` must be raised with the correct logcontext. @@ -501,14 +513,14 @@ class DescriptorTestCase(unittest.TestCase): inner_context_was_finished = False @cached() - async def fn(self, arg1): + async def fn(self, arg1: int) -> str: await make_deferred_yieldable(complete_lookup) self.inner_context_was_finished = current_context().finished return str(arg1) obj = Cls() - async def do_lookup(): + async def do_lookup() -> None: with LoggingContext("c1") as c1: try: await obj.fn(123) @@ -542,10 +554,10 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): """ @defer.inlineCallbacks - def test_passthrough(self): + def test_passthrough(self) -> Generator["Deferred[Any]", object, None]: class A: @cached() - def func(self, key): + def func(self, key: str) -> str: return key a = A() @@ -554,12 +566,12 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): self.assertEqual((yield a.func("bar")), "bar") @defer.inlineCallbacks - def test_hit(self): + def test_hit(self) -> Generator["Deferred[Any]", object, None]: callcount = [0] class A: @cached() - def func(self, key): + def func(self, key: str) -> str: callcount[0] += 1 return key @@ -572,12 +584,12 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): self.assertEqual(callcount[0], 1) @defer.inlineCallbacks - def test_invalidate(self): + def test_invalidate(self) -> Generator["Deferred[Any]", object, None]: callcount = [0] class A: @cached() - def func(self, key): + def func(self, key: str) -> str: callcount[0] += 1 return key @@ -592,48 +604,48 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): self.assertEqual(callcount[0], 2) - def test_invalidate_missing(self): + def test_invalidate_missing(self) -> None: class A: @cached() - def func(self, key): + def func(self, key: str) -> str: return key A().func.invalidate(("what",)) @defer.inlineCallbacks - def test_max_entries(self): + def test_max_entries(self) -> Generator["Deferred[Any]", object, None]: callcount = [0] class A: @cached(max_entries=10) - def func(self, key): + def func(self, key: int) -> int: callcount[0] += 1 return key a = A() - for k in range(0, 12): + for k in range(12): yield a.func(k) self.assertEqual(callcount[0], 12) # There must have been at least 2 evictions, meaning if we calculate # all 12 values again, we must get called at least 2 more times - for k in range(0, 12): + for k in range(12): yield a.func(k) self.assertTrue( callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0]) ) - def test_prefill(self): + def test_prefill(self) -> None: callcount = [0] d = defer.succeed(123) class A: @cached() - def func(self, key): + def func(self, key: str) -> "Deferred[int]": callcount[0] += 1 return d @@ -645,18 +657,18 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): self.assertEqual(callcount[0], 0) @defer.inlineCallbacks - def test_invalidate_context(self): + def test_invalidate_context(self) -> Generator["Deferred[Any]", object, None]: callcount = [0] callcount2 = [0] class A: @cached() - def func(self, key): + def func(self, key: str) -> str: callcount[0] += 1 return key @cached(cache_context=True) - def func2(self, key, cache_context): + def func2(self, key: str, cache_context: _CacheContext) -> "Deferred[str]": callcount2[0] += 1 return self.func(key, on_invalidate=cache_context.invalidate) @@ -678,18 +690,18 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): self.assertEqual(callcount2[0], 2) @defer.inlineCallbacks - def test_eviction_context(self): + def test_eviction_context(self) -> Generator["Deferred[Any]", object, None]: callcount = [0] callcount2 = [0] class A: @cached(max_entries=2) - def func(self, key): + def func(self, key: str) -> str: callcount[0] += 1 return key @cached(cache_context=True) - def func2(self, key, cache_context): + def func2(self, key: str, cache_context: _CacheContext) -> "Deferred[str]": callcount2[0] += 1 return self.func(key, on_invalidate=cache_context.invalidate) @@ -715,18 +727,18 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): self.assertEqual(callcount2[0], 3) @defer.inlineCallbacks - def test_double_get(self): + def test_double_get(self) -> Generator["Deferred[Any]", object, None]: callcount = [0] callcount2 = [0] class A: @cached() - def func(self, key): + def func(self, key: str) -> str: callcount[0] += 1 return key @cached(cache_context=True) - def func2(self, key, cache_context): + def func2(self, key: str, cache_context: _CacheContext) -> "Deferred[str]": callcount2[0] += 1 return self.func(key, on_invalidate=cache_context.invalidate) @@ -763,17 +775,19 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): class CachedListDescriptorTestCase(unittest.TestCase): @defer.inlineCallbacks - def test_cache(self): + def test_cache(self) -> Generator["Deferred[Any]", object, None]: class Cls: - def __init__(self): + def __init__(self) -> None: self.mock = mock.Mock() @descriptors.cached() - def fn(self, arg1, arg2): + def fn(self, arg1: int, arg2: int) -> None: pass @descriptors.cachedList(cached_method_name="fn", list_name="args1") - async def list_fn(self, args1, arg2): + async def list_fn( + self, args1: Iterable[int], arg2: int + ) -> Mapping[int, str]: context = current_context() assert isinstance(context, LoggingContext) assert context.name == "c1" @@ -824,23 +838,23 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj.mock.assert_called_once_with({40}, 2) self.assertEqual(r, {10: "fish", 40: "gravy"}) - def test_concurrent_lookups(self): + def test_concurrent_lookups(self) -> None: """All concurrent lookups should get the same result""" class Cls: - def __init__(self): + def __init__(self) -> None: self.mock = mock.Mock() @descriptors.cached() - def fn(self, arg1): + def fn(self, arg1: int) -> None: pass @descriptors.cachedList(cached_method_name="fn", list_name="args1") - def list_fn(self, args1) -> "Deferred[dict]": + def list_fn(self, args1: List[int]) -> "Deferred[Mapping[int, str]]": return self.mock(args1) obj = Cls() - deferred_result: "Deferred[dict]" = Deferred() + deferred_result: "Deferred[Mapping[int, str]]" = Deferred() obj.mock.return_value = deferred_result # start off several concurrent lookups of the same key @@ -867,19 +881,19 @@ class CachedListDescriptorTestCase(unittest.TestCase): self.assertEqual(self.successResultOf(d3), {10: "peas"}) @defer.inlineCallbacks - def test_invalidate(self): + def test_invalidate(self) -> Generator["Deferred[Any]", object, None]: """Make sure that invalidation callbacks are called.""" class Cls: - def __init__(self): + def __init__(self) -> None: self.mock = mock.Mock() @descriptors.cached() - def fn(self, arg1, arg2): + def fn(self, arg1: int, arg2: int) -> None: pass @descriptors.cachedList(cached_method_name="fn", list_name="args1") - async def list_fn(self, args1, arg2): + async def list_fn(self, args1: List[int], arg2: int) -> Mapping[int, str]: # we want this to behave like an asynchronous function await run_on_reactor() return self.mock(args1, arg2) @@ -908,17 +922,17 @@ class CachedListDescriptorTestCase(unittest.TestCase): invalidate0.assert_called_once() invalidate1.assert_called_once() - def test_cancel(self): + def test_cancel(self) -> None: """Test that cancelling a lookup does not cancel other lookups""" complete_lookup: "Deferred[None]" = Deferred() class Cls: @cached() - def fn(self, arg1): + def fn(self, arg1: int) -> None: pass @cachedList(cached_method_name="fn", list_name="args") - async def list_fn(self, args): + async def list_fn(self, args: List[int]) -> Mapping[int, str]: await complete_lookup return {arg: str(arg) for arg in args} @@ -936,7 +950,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): self.failureResultOf(d1, CancelledError) self.assertEqual(d2.result, {123: "123", 456: "456", 789: "789"}) - def test_cancel_logcontexts(self): + def test_cancel_logcontexts(self) -> None: """Test that cancellation does not break logcontexts. * The `CancelledError` must be raised with the correct logcontext. @@ -949,18 +963,18 @@ class CachedListDescriptorTestCase(unittest.TestCase): inner_context_was_finished = False @cached() - def fn(self, arg1): + def fn(self, arg1: int) -> None: pass @cachedList(cached_method_name="fn", list_name="args") - async def list_fn(self, args): + async def list_fn(self, args: List[int]) -> Mapping[int, str]: await make_deferred_yieldable(complete_lookup) self.inner_context_was_finished = current_context().finished return {arg: str(arg) for arg in args} obj = Cls() - async def do_lookup(): + async def do_lookup() -> None: with LoggingContext("c1") as c1: try: await obj.list_fn([123]) @@ -983,7 +997,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): ) self.assertEqual(current_context(), SENTINEL_CONTEXT) - def test_num_args_mismatch(self): + def test_num_args_mismatch(self) -> None: """ Make sure someone does not accidentally use @cachedList on a method with a mismatch in the number args to the underlying single cache method. @@ -991,14 +1005,14 @@ class CachedListDescriptorTestCase(unittest.TestCase): class Cls: @descriptors.cached(tree=True) - def fn(self, room_id, event_id): + def fn(self, room_id: str, event_id: str) -> None: pass # This is wrong ❌. `@cachedList` expects to be given the same number # of arguments as the underlying cached function, just with one of # the arguments being an iterable @descriptors.cachedList(cached_method_name="fn", list_name="keys") - def list_fn(self, keys: Iterable[Tuple[str, str]]): + def list_fn(self, keys: Iterable[Tuple[str, str]]) -> None: pass # Corrected syntax ✅ diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index 91cac9822a..05983ed434 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -60,11 +60,9 @@ class ObservableDeferredTest(TestCase): observer1.addBoth(check_called_first) # store the results - results: List[Optional[ObservableDeferred[int]]] = [None, None] + results: List[Optional[int]] = [None, None] - def check_val( - res: ObservableDeferred[int], idx: int - ) -> ObservableDeferred[int]: + def check_val(res: int, idx: int) -> int: results[idx] = res return res @@ -93,14 +91,14 @@ class ObservableDeferredTest(TestCase): observer1.addBoth(check_called_first) # store the results - results: List[Optional[ObservableDeferred[str]]] = [None, None] + results: List[Optional[Failure]] = [None, None] - def check_val(res: ObservableDeferred[str], idx: int) -> None: + def check_failure(res: Failure, idx: int) -> None: results[idx] = res return None - observer1.addErrback(check_val, 0) - observer2.addErrback(check_val, 1) + observer1.addErrback(check_failure, 0) + observer2.addErrback(check_failure, 1) try: raise Exception("gah!") diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 5f8f4e76b5..4bcd17a6fc 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -11,12 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.util.retryutils import ( - MIN_RETRY_INTERVAL, - RETRY_MULTIPLIER, - NotRetryingDestination, - get_retry_limiter, -) +from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter from tests.unittest import HomeserverTestCase @@ -42,6 +37,11 @@ class RetryLimiterTestCase(HomeserverTestCase): limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) + min_retry_interval_ms = ( + self.hs.config.federation.destination_min_retry_interval_ms + ) + retry_multiplier = self.hs.config.federation.destination_retry_multiplier + self.pump(1) try: with limiter: @@ -57,7 +57,7 @@ class RetryLimiterTestCase(HomeserverTestCase): assert new_timings is not None self.assertEqual(new_timings.failure_ts, failure_ts) self.assertEqual(new_timings.retry_last_ts, failure_ts) - self.assertEqual(new_timings.retry_interval, MIN_RETRY_INTERVAL) + self.assertEqual(new_timings.retry_interval, min_retry_interval_ms) # now if we try again we should get a failure self.get_failure( @@ -68,7 +68,7 @@ class RetryLimiterTestCase(HomeserverTestCase): # advance the clock and try again # - self.pump(MIN_RETRY_INTERVAL) + self.pump(min_retry_interval_ms) limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) @@ -87,16 +87,16 @@ class RetryLimiterTestCase(HomeserverTestCase): self.assertEqual(new_timings.failure_ts, failure_ts) self.assertEqual(new_timings.retry_last_ts, retry_ts) self.assertGreaterEqual( - new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 0.5 + new_timings.retry_interval, min_retry_interval_ms * retry_multiplier * 0.5 ) self.assertLessEqual( - new_timings.retry_interval, MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0 + new_timings.retry_interval, min_retry_interval_ms * retry_multiplier * 2.0 ) # # one more go, with success # - self.reactor.advance(MIN_RETRY_INTERVAL * RETRY_MULTIPLIER * 2.0) + self.reactor.advance(min_retry_interval_ms * retry_multiplier * 2.0) limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) self.pump(1) @@ -108,3 +108,54 @@ class RetryLimiterTestCase(HomeserverTestCase): new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertIsNone(new_timings) + + def test_max_retry_interval(self) -> None: + """Test that `destination_max_retry_interval` setting works as expected""" + store = self.hs.get_datastores().main + + destination_max_retry_interval_ms = ( + self.hs.config.federation.destination_max_retry_interval_ms + ) + + self.get_success(get_retry_limiter("test_dest", self.clock, store)) + self.pump(1) + + failure_ts = self.clock.time_msec() + + # Simulate reaching destination_max_retry_interval + self.get_success( + store.set_destination_retry_timings( + "test_dest", + failure_ts=failure_ts, + retry_last_ts=failure_ts, + retry_interval=destination_max_retry_interval_ms, + ) + ) + + # Check it fails + self.get_failure( + get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination + ) + + # Get past retry_interval and we can try again, and still throw an error to continue the backoff + self.reactor.advance(destination_max_retry_interval_ms / 1000 + 1) + limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) + self.pump(1) + try: + with limiter: + self.pump(1) + raise AssertionError("argh") + except AssertionError: + pass + + self.pump() + + # retry_interval does not increase and stays at destination_max_retry_interval_ms + new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) + assert new_timings is not None + self.assertEqual(new_timings.retry_interval, destination_max_retry_interval_ms) + + # Check it fails + self.get_failure( + get_retry_limiter("test_dest", self.clock, store), NotRetryingDestination + ) diff --git a/tests/util/test_task_scheduler.py b/tests/util/test_task_scheduler.py new file mode 100644 index 0000000000..8665aeb50c --- /dev/null +++ b/tests/util/test_task_scheduler.py @@ -0,0 +1,208 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple + +from twisted.internet.task import deferLater +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.types import JsonMapping, ScheduledTask, TaskStatus +from synapse.util import Clock +from synapse.util.task_scheduler import TaskScheduler + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.unittest import HomeserverTestCase, override_config + + +class TestTaskScheduler(HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.task_scheduler = hs.get_task_scheduler() + self.task_scheduler.register_action(self._test_task, "_test_task") + self.task_scheduler.register_action(self._sleeping_task, "_sleeping_task") + self.task_scheduler.register_action(self._raising_task, "_raising_task") + self.task_scheduler.register_action(self._resumable_task, "_resumable_task") + + async def _test_task( + self, task: ScheduledTask + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + # This test task will copy the parameters to the result + result = None + if task.params: + result = task.params + return (TaskStatus.COMPLETE, result, None) + + def test_schedule_task(self) -> None: + """Schedule a task in the future with some parameters to be copied as a result and check it executed correctly. + Also check that it get removed after `KEEP_TASKS_FOR_MS`.""" + timestamp = self.clock.time_msec() + 30 * 1000 + task_id = self.get_success( + self.task_scheduler.schedule_task( + "_test_task", + timestamp=timestamp, + params={"val": 1}, + ) + ) + + task = self.get_success(self.task_scheduler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.SCHEDULED) + self.assertIsNone(task.result) + + # The timestamp being 30s after now the task should been executed + # after the first scheduling loop is run + self.reactor.advance(TaskScheduler.SCHEDULE_INTERVAL_MS / 1000) + + task = self.get_success(self.task_scheduler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.COMPLETE) + assert task.result is not None + # The passed parameter should have been copied to the result + self.assertTrue(task.result.get("val") == 1) + + # Let's wait for the complete task to be deleted and hence unavailable + self.reactor.advance((TaskScheduler.KEEP_TASKS_FOR_MS / 1000) + 1) + + task = self.get_success(self.task_scheduler.get_task(task_id)) + self.assertIsNone(task) + + async def _sleeping_task( + self, task: ScheduledTask + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + # Sleep for a second + await deferLater(self.reactor, 1, lambda: None) + return TaskStatus.COMPLETE, None, None + + def test_schedule_lot_of_tasks(self) -> None: + """Schedule more than `TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS` tasks and check the behavior.""" + task_ids = [] + for i in range(TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS + 1): + task_ids.append( + self.get_success( + self.task_scheduler.schedule_task( + "_sleeping_task", + params={"val": i}, + ) + ) + ) + + # This is to give the time to the active tasks to finish + self.reactor.advance(1) + + # Check that only MAX_CONCURRENT_RUNNING_TASKS tasks has run and that one + # is still scheduled. + tasks = [ + self.get_success(self.task_scheduler.get_task(task_id)) + for task_id in task_ids + ] + + self.assertEquals( + len( + [t for t in tasks if t is not None and t.status == TaskStatus.COMPLETE] + ), + TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS, + ) + + scheduled_tasks = [ + t for t in tasks if t is not None and t.status == TaskStatus.ACTIVE + ] + self.assertEquals(len(scheduled_tasks), 1) + + # We need to wait for the next run of the scheduler loop + self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) + self.reactor.advance(1) + + # Check that the last task has been properly executed after the next scheduler loop run + prev_scheduled_task = self.get_success( + self.task_scheduler.get_task(scheduled_tasks[0].id) + ) + assert prev_scheduled_task is not None + self.assertEquals( + prev_scheduled_task.status, + TaskStatus.COMPLETE, + ) + + async def _raising_task( + self, task: ScheduledTask + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + raise Exception("raising") + + def test_schedule_raising_task(self) -> None: + """Schedule a task raising an exception and check it runs to failure and report exception content.""" + task_id = self.get_success(self.task_scheduler.schedule_task("_raising_task")) + + task = self.get_success(self.task_scheduler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.FAILED) + self.assertEqual(task.error, "raising") + + async def _resumable_task( + self, task: ScheduledTask + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + if task.result and "in_progress" in task.result: + return TaskStatus.COMPLETE, {"success": True}, None + else: + await self.task_scheduler.update_task(task.id, result={"in_progress": True}) + # Await forever to simulate an aborted task because of a restart + await deferLater(self.reactor, 2**16, lambda: None) + # This should never been called + return TaskStatus.ACTIVE, None, None + + def test_schedule_resumable_task(self) -> None: + """Schedule a resumable task and check that it gets properly resumed and complete after simulating a synapse restart.""" + task_id = self.get_success(self.task_scheduler.schedule_task("_resumable_task")) + + task = self.get_success(self.task_scheduler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.ACTIVE) + + # Simulate a synapse restart by emptying the list of running tasks + self.task_scheduler._running_tasks = set() + self.reactor.advance((TaskScheduler.SCHEDULE_INTERVAL_MS / 1000)) + + task = self.get_success(self.task_scheduler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.COMPLETE) + assert task.result is not None + self.assertTrue(task.result.get("success")) + + +class TestTaskSchedulerWithBackgroundWorker(BaseMultiWorkerStreamTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.task_scheduler = hs.get_task_scheduler() + self.task_scheduler.register_action(self._test_task, "_test_task") + + async def _test_task( + self, task: ScheduledTask + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + return (TaskStatus.COMPLETE, None, None) + + @override_config({"run_background_tasks_on": "worker1"}) + def test_schedule_task(self) -> None: + """Check that a task scheduled to run now is launch right away on the background worker.""" + bg_worker_hs = self.make_worker_hs( + "synapse.app.generic_worker", + extra_config={"worker_name": "worker1"}, + ) + bg_worker_hs.get_task_scheduler().register_action(self._test_task, "_test_task") + + task_id = self.get_success( + self.task_scheduler.schedule_task( + "_test_task", + ) + ) + + task = self.get_success(self.task_scheduler.get_task(task_id)) + assert task is not None + self.assertEqual(task.status, TaskStatus.COMPLETE) |