summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_appservice.py122
-rw-r--r--tests/handlers/test_presence.py163
-rw-r--r--tests/handlers/test_space_summary.py185
-rw-r--r--tests/module_api/test_api.py10
-rw-r--r--tests/push/test_email.py20
-rw-r--r--tests/rest/client/v1/test_rooms.py92
-rw-r--r--tests/rest/client/v2_alpha/test_account.py33
-rw-r--r--tests/rest/client/v2_alpha/test_register.py12
-rw-r--r--tests/storage/databases/main/test_events_worker.py50
-rw-r--r--tests/test_federation.py6
10 files changed, 580 insertions, 113 deletions
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py

index 024c5e963c..43998020b2 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py
@@ -133,11 +133,131 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.assertEquals(result.room_id, room_id) self.assertEquals(result.servers, servers) - def _mkservice(self, is_interested): + def test_get_3pe_protocols_no_appservices(self): + self.mock_store.get_app_services.return_value = [] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_no_protocols(self): + service = self._mkservice(False, []) + self.mock_store.get_app_services.return_value = [service] + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_not_called() + self.assertEquals(response, {}) + + def test_get_3pe_protocols_protocol_no_response(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals(response, {}) + + def test_get_3pe_protocols_select_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol")) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_one_protocol(self): + service = self._mkservice(False, ["my-protocol"]) + self.mock_store.get_app_services.return_value = [service] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called_once_with( + service, "my-protocol" + ) + self.assertEquals( + response, {"my-protocol": {"x-protocol-data": 42, "instances": []}} + ) + + def test_get_3pe_protocols_multiple_protocol(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["other-protocol"]) + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol.return_value = make_awaitable( + {"x-protocol-data": 42, "instances": []} + ) + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + self.mock_as_api.get_3pe_protocol.assert_called() + self.assertEquals( + response, + { + "my-protocol": {"x-protocol-data": 42, "instances": []}, + "other-protocol": {"x-protocol-data": 42, "instances": []}, + }, + ) + + def test_get_3pe_protocols_multiple_info(self): + service_one = self._mkservice(False, ["my-protocol"]) + service_two = self._mkservice(False, ["my-protocol"]) + + async def get_3pe_protocol(service, unusedProtocol): + if service == service_one: + return { + "x-protocol-data": 42, + "instances": [{"desc": "Alice's service"}], + } + if service == service_two: + return { + "x-protocol-data": 36, + "x-not-used": 45, + "instances": [{"desc": "Bob's service"}], + } + raise Exception("Unexpected service") + + self.mock_store.get_app_services.return_value = [service_one, service_two] + self.mock_as_api.get_3pe_protocol = get_3pe_protocol + response = self.successResultOf( + defer.ensureDeferred(self.handler.get_3pe_protocols()) + ) + # It's expected that the second service's data doesn't appear in the response + self.assertEquals( + response, + { + "my-protocol": { + "x-protocol-data": 42, + "instances": [ + { + "desc": "Alice's service", + }, + {"desc": "Bob's service"}, + ], + }, + }, + ) + + def _mkservice(self, is_interested, protocols=None): service = Mock() service.is_interested.return_value = make_awaitable(is_interested) service.token = "mock_service_token" service.url = "mock_service_url" + service.protocols = protocols return service def _mkservice_alias(self, is_interested_in_alias): diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 18e92e90d7..29845a80da 100644 --- a/tests/handlers/test_presence.py +++ b/tests/handlers/test_presence.py
@@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Optional from unittest.mock import Mock, call from signedjson.key import generate_signing_key @@ -339,8 +339,11 @@ class PresenceUpdateTestCase(unittest.HomeserverTestCase): class PresenceTimeoutTestCase(unittest.TestCase): + """Tests different timers and that the timer does not change `status_msg` of user.""" + def test_idle_timer(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -348,12 +351,14 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.ONLINE, last_active_ts=now - IDLE_TIMER - 1, last_user_sync_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.UNAVAILABLE) + self.assertEquals(new_state.status_msg, status_msg) def test_busy_no_idle(self): """ @@ -361,6 +366,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): presence state into unavailable. """ user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -368,15 +374,18 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.BUSY, last_active_ts=now - IDLE_TIMER - 1, last_user_sync_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.BUSY) + self.assertEquals(new_state.status_msg, status_msg) def test_sync_timeout(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -384,15 +393,18 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.ONLINE, last_active_ts=0, last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.OFFLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_sync_online(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -400,6 +412,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): state=PresenceState.ONLINE, last_active_ts=now - SYNC_ONLINE_TIMEOUT - 1, last_user_sync_ts=now - SYNC_ONLINE_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout( @@ -408,9 +421,11 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.ONLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_federation_ping(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -419,12 +434,13 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_active_ts=now, last_user_sync_ts=now, last_federation_update_ts=now - FEDERATION_PING_INTERVAL - 1, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) self.assertIsNotNone(new_state) - self.assertEquals(new_state, new_state) + self.assertEquals(state, new_state) def test_no_timeout(self): user_id = "@foo:bar" @@ -444,6 +460,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): def test_federation_timeout(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -452,6 +469,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_active_ts=now, last_user_sync_ts=now, last_federation_update_ts=now - FEDERATION_TIMEOUT - 1, + status_msg=status_msg, ) new_state = handle_timeout( @@ -460,9 +478,11 @@ class PresenceTimeoutTestCase(unittest.TestCase): self.assertIsNotNone(new_state) self.assertEquals(new_state.state, PresenceState.OFFLINE) + self.assertEquals(new_state.status_msg, status_msg) def test_last_active(self): user_id = "@foo:bar" + status_msg = "I'm here!" now = 5000000 state = UserPresenceState.default(user_id) @@ -471,6 +491,7 @@ class PresenceTimeoutTestCase(unittest.TestCase): last_active_ts=now - LAST_ACTIVE_GRANULARITY - 1, last_user_sync_ts=now, last_federation_update_ts=now, + status_msg=status_msg, ) new_state = handle_timeout(state, is_mine=True, syncing_user_ids=set(), now=now) @@ -516,6 +537,144 @@ class PresenceHandlerTestCase(unittest.HomeserverTestCase): ) self.assertEqual(state.state, PresenceState.OFFLINE) + def test_user_goes_offline_by_timeout_status_msg_remain(self): + """Test that if a user doesn't update the records for a while + users presence goes `OFFLINE` because of timeout and `status_msg` remains. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Check that if we wait a while without telling the handler the user has + # stopped syncing that their presence state doesn't get timed out. + self.reactor.advance(SYNC_ONLINE_TIMEOUT / 2) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.ONLINE) + self.assertEqual(state.status_msg, status_msg) + + # Check that if the timeout fires, then the syncing user gets timed out + self.reactor.advance(SYNC_ONLINE_TIMEOUT) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + # status_msg should remain even after going offline + self.assertEqual(state.state, PresenceState.OFFLINE) + self.assertEqual(state.status_msg, status_msg) + + def test_user_goes_offline_manually_with_no_status_msg(self): + """Test that if a user change presence manually to `OFFLINE` + and no status is set, that `status_msg` is `None`. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as offline + self.get_success( + self.presence_handler.set_state( + UserID.from_string(user_id), {"presence": PresenceState.OFFLINE} + ) + ) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(state.state, PresenceState.OFFLINE) + self.assertEqual(state.status_msg, None) + + def test_user_goes_offline_manually_with_status_msg(self): + """Test that if a user change presence manually to `OFFLINE` + and a status is set, that `status_msg` appears. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as offline + self._set_presencestate_with_status_msg( + user_id, PresenceState.OFFLINE, "And now here." + ) + + def test_user_reset_online_with_no_status(self): + """Test that if a user set again the presence manually + and no status is set, that `status_msg` is `None`. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as online again + self.get_success( + self.presence_handler.set_state( + UserID.from_string(user_id), {"presence": PresenceState.ONLINE} + ) + ) + + state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + # status_msg should remain even after going offline + self.assertEqual(state.state, PresenceState.ONLINE) + self.assertEqual(state.status_msg, None) + + def test_set_presence_with_status_msg_none(self): + """Test that if a user set again the presence manually + and status is `None`, that `status_msg` is `None`. + """ + user_id = "@test:server" + status_msg = "I'm here!" + + # Mark user as online + self._set_presencestate_with_status_msg( + user_id, PresenceState.ONLINE, status_msg + ) + + # Mark user as online and `status_msg = None` + self._set_presencestate_with_status_msg(user_id, PresenceState.ONLINE, None) + + def _set_presencestate_with_status_msg( + self, user_id: str, state: PresenceState, status_msg: Optional[str] + ): + """Set a PresenceState and status_msg and check the result. + + Args: + user_id: User for that the status is to be set. + PresenceState: The new PresenceState. + status_msg: Status message that is to be set. + """ + self.get_success( + self.presence_handler.set_state( + UserID.from_string(user_id), + {"presence": state, "status_msg": status_msg}, + ) + ) + + new_state = self.get_success( + self.presence_handler.get_state(UserID.from_string(user_id)) + ) + self.assertEqual(new_state.state, state) + self.assertEqual(new_state.status_msg, status_msg) + class PresenceFederationQueueTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py
index 01975c13d4..6cc1a02e12 100644 --- a/tests/handlers/test_space_summary.py +++ b/tests/handlers/test_space_summary.py
@@ -26,7 +26,7 @@ from synapse.api.constants import ( from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict -from synapse.handlers.space_summary import _child_events_comparison_key +from synapse.handlers.space_summary import _child_events_comparison_key, _RoomEntry from synapse.rest import admin from synapse.rest.client.v1 import login, room from synapse.server import HomeServer @@ -351,26 +351,30 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): # events before child events). # Note that these entries are brief, but should contain enough info. - rooms = [ - { - "room_id": subspace, - "world_readable": True, - "room_type": RoomTypes.SPACE, - }, - { - "room_id": subroom, - "world_readable": True, - }, - ] - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": subroom, - "content": event_content, - }, + return [ + _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + "room_type": RoomTypes.SPACE, + }, + [ + { + "room_id": subspace, + "state_key": subroom, + "content": {"via": [fed_hostname]}, + } + ], + ), + _RoomEntry( + subroom, + { + "room_id": subroom, + "world_readable": True, + }, + ), ] - return rooms, events # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) @@ -436,70 +440,95 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase): ): # Note that these entries are brief, but should contain enough info. rooms = [ - { - "room_id": public_room, - "world_readable": False, - "join_rules": JoinRules.PUBLIC, - }, - { - "room_id": knock_room, - "world_readable": False, - "join_rules": JoinRules.KNOCK, - }, - { - "room_id": not_invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": invited_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": restricted_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [], - }, - { - "room_id": restricted_accessible_room, - "world_readable": False, - "join_rules": JoinRules.RESTRICTED, - "allowed_spaces": [self.room], - }, - { - "room_id": world_readable_room, - "world_readable": True, - "join_rules": JoinRules.INVITE, - }, - { - "room_id": joined_room, - "world_readable": False, - "join_rules": JoinRules.INVITE, - }, - ] - - # Place each room in the sub-space. - event_content = {"via": [fed_hostname]} - events = [ - { - "room_id": subspace, - "state_key": room["room_id"], - "content": event_content, - } - for room in rooms + _RoomEntry( + public_room, + { + "room_id": public_room, + "world_readable": False, + "join_rules": JoinRules.PUBLIC, + }, + ), + _RoomEntry( + knock_room, + { + "room_id": knock_room, + "world_readable": False, + "join_rules": JoinRules.KNOCK, + }, + ), + _RoomEntry( + not_invited_room, + { + "room_id": not_invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + invited_room, + { + "room_id": invited_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + restricted_room, + { + "room_id": restricted_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [], + }, + ), + _RoomEntry( + restricted_accessible_room, + { + "room_id": restricted_accessible_room, + "world_readable": False, + "join_rules": JoinRules.RESTRICTED, + "allowed_spaces": [self.room], + }, + ), + _RoomEntry( + world_readable_room, + { + "room_id": world_readable_room, + "world_readable": True, + "join_rules": JoinRules.INVITE, + }, + ), + _RoomEntry( + joined_room, + { + "room_id": joined_room, + "world_readable": False, + "join_rules": JoinRules.INVITE, + }, + ), ] # Also include the subspace. rooms.insert( 0, - { - "room_id": subspace, - "world_readable": True, - }, + _RoomEntry( + subspace, + { + "room_id": subspace, + "world_readable": True, + }, + # Place each room in the sub-space. + [ + { + "room_id": subspace, + "state_key": room.room_id, + "content": {"via": [fed_hostname]}, + } + for room in rooms + ], + ), ) - return rooms, events + return rooms # Add a room to the space which is on another server. self._add_child(self.space, subspace, self.token) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 81d9e2f484..0b817cc701 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py
@@ -79,6 +79,16 @@ class ModuleApiTestCase(HomeserverTestCase): displayname = self.get_success(self.store.get_profile_displayname("bob")) self.assertEqual(displayname, "Bobberino") + def test_get_userinfo_by_id(self): + user_id = self.register_user("alice", "1234") + found_user = self.get_success(self.module_api.get_userinfo_by_id(user_id)) + self.assertEqual(found_user.user_id.to_string(), user_id) + self.assertIdentical(found_user.is_admin, False) + + def test_get_userinfo_by_id__no_user_found(self): + found_user = self.get_success(self.module_api.get_userinfo_by_id("@alice:test")) + self.assertIsNone(found_user) + def test_sending_events_into_room(self): """Tests that a module can send events into a room""" # Mock out create_and_send_nonmember_event to check whether events are being sent diff --git a/tests/push/test_email.py b/tests/push/test_email.py
index e04bc5c9a6..a487706758 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py
@@ -45,14 +45,6 @@ class EmailPusherTests(HomeserverTestCase): def make_homeserver(self, reactor, clock): - # List[Tuple[Deferred, args, kwargs]] - self.email_attempts = [] - - def sendmail(*args, **kwargs): - d = Deferred() - self.email_attempts.append((d, args, kwargs)) - return d - config = self.default_config() config["email"] = { "enable_notifs": True, @@ -75,7 +67,17 @@ class EmailPusherTests(HomeserverTestCase): config["public_baseurl"] = "aaa" config["start_pushers"] = True - hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + hs = self.setup_test_homeserver(config=config) + + # List[Tuple[Deferred, args, kwargs]] + self.email_attempts = [] + + def sendmail(*args, **kwargs): + d = Deferred() + self.email_attempts.append((d, args, kwargs)) + return d + + hs.get_send_email_handler()._sendmail = sendmail return hs diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 3df070c936..1a9528ec20 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py
@@ -19,11 +19,14 @@ import json from typing import Iterable -from unittest.mock import Mock +from unittest.mock import Mock, call from urllib import parse as urlparse +from twisted.internet import defer + import synapse.rest.admin from synapse.api.constants import EventContentFields, EventTypes, Membership +from synapse.api.errors import HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client.v1 import directory, login, profile, room @@ -1124,6 +1127,93 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) +class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): + """Test that we correctly fallback to local filtering if a remote server + doesn't support search. + """ + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + room.register_servlets, + login.register_servlets, + ] + + def make_homeserver(self, reactor, clock): + return self.setup_test_homeserver(federation_client=Mock()) + + def prepare(self, reactor, clock, hs): + self.register_user("user", "pass") + self.token = self.login("user", "pass") + + self.federation_client = hs.get_federation_client() + + def test_simple(self): + "Simple test for searching rooms over federation" + self.federation_client.get_public_rooms.side_effect = ( + lambda *a, **k: defer.succeed({}) + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_called_once_with( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ) + + def test_fallback(self): + "Test that searching public rooms over federation falls back if it gets a 404" + + # The `get_public_rooms` should be called again if the first call fails + # with a 404, when using search filters. + self.federation_client.get_public_rooms.side_effect = ( + HttpResponseException(404, "Not Found", b""), + defer.succeed({}), + ) + + search_filter = {"generic_search_term": "foobar"} + + channel = self.make_request( + "POST", + b"/_matrix/client/r0/publicRooms?server=testserv", + content={"filter": search_filter}, + access_token=self.token, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.federation_client.get_public_rooms.assert_has_calls( + [ + call( + "testserv", + limit=100, + since_token=None, + search_filter=search_filter, + include_all_networks=False, + third_party_instance_id=None, + ), + call( + "testserv", + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ), + ] + ) + + class PerRoomProfilesForbiddenTestCase(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py
index 317a2287e3..e7e617e9df 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py
@@ -47,12 +47,6 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): config = self.default_config() # Email config. - self.email_attempts = [] - - async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - return - config["email"] = { "enable_notifs": False, "template_dir": os.path.abspath( @@ -67,7 +61,16 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "https://example.com" - hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + hs.get_send_email_handler()._sendmail = sendmail + return hs def prepare(self, reactor, clock, hs): @@ -511,11 +514,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): config = self.default_config() # Email config. - self.email_attempts = [] - - async def sendmail(smtphost, from_addr, to_addrs, msg, **kwargs): - self.email_attempts.append(msg) - config["email"] = { "enable_notifs": False, "template_dir": os.path.abspath( @@ -530,7 +528,16 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "https://example.com" - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail( + reactor, smtphost, smtpport, from_addr, to_addrs, msg, **kwargs + ): + self.email_attempts.append(msg) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail + return self.hs def prepare(self, reactor, clock, hs): diff --git a/tests/rest/client/v2_alpha/test_register.py b/tests/rest/client/v2_alpha/test_register.py
index 1cad5f00eb..a52e5e608a 100644 --- a/tests/rest/client/v2_alpha/test_register.py +++ b/tests/rest/client/v2_alpha/test_register.py
@@ -509,10 +509,6 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): } # Email config. - self.email_attempts = [] - - async def sendmail(*args, **kwargs): - self.email_attempts.append((args, kwargs)) config["email"] = { "enable_notifs": True, @@ -532,7 +528,13 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase): } config["public_baseurl"] = "aaa" - self.hs = self.setup_test_homeserver(config=config, sendmail=sendmail) + self.hs = self.setup_test_homeserver(config=config) + + async def sendmail(*args, **kwargs): + self.email_attempts.append((args, kwargs)) + + self.email_attempts = [] + self.hs.get_send_email_handler()._sendmail = sendmail self.store = self.hs.get_datastore() diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 932970fd9a..d05d367685 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py
@@ -14,7 +14,10 @@ import json from synapse.logging.context import LoggingContext +from synapse.rest import admin +from synapse.rest.client.v1 import login, room from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.util.async_helpers import yieldable_gather_results from tests import unittest @@ -94,3 +97,50 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): res = self.get_success(self.store.have_seen_events("room1", ["event10"])) self.assertEquals(res, {"event10"}) self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + + +class EventCacheTestCase(unittest.HomeserverTestCase): + """Test that the various layers of event cache works.""" + + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.store: EventsWorkerStore = hs.get_datastore() + + self.user = self.register_user("user", "pass") + self.token = self.login(self.user, "pass") + + self.room = self.helper.create_room_as(self.user, tok=self.token) + + res = self.helper.send(self.room, tok=self.token) + self.event_id = res["event_id"] + + # Reset the event cache so the tests start with it empty + self.store._get_event_cache.clear() + + def test_simple(self): + """Test that we cache events that we pull from the DB.""" + + with LoggingContext("test") as ctx: + self.get_success(self.store.get_event(self.event_id)) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) + + def test_dedupe(self): + """Test that if we request the same event multiple times we only pull it + out once. + """ + + with LoggingContext("test") as ctx: + d = yieldable_gather_results( + self.store.get_event, [self.event_id, self.event_id] + ) + self.get_success(d) + + # We should have fetched the event from the DB + self.assertEqual(ctx.get_resource_usage().evt_db_fetch_count, 1) diff --git a/tests/test_federation.py b/tests/test_federation.py
index 0ed8326f55..3785799f46 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py
@@ -75,10 +75,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase): ) self.handler = self.homeserver.get_federation_handler() - self.handler._check_event_auth = ( - lambda origin, event, context, state, auth_events, backfilled: succeed( - context - ) + self.handler._check_event_auth = lambda origin, event, context, state, claimed_auth_event_map, backfilled: succeed( + context ) self.client = self.homeserver.get_federation_client() self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(