diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index ddcf3ee348..734ed84d78 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -13,8 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Iterable
from unittest import mock
+from parameterized import parameterized
from signedjson import key as key, sign as sign
from twisted.internet import defer
@@ -23,6 +25,7 @@ from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import Codes, SynapseError
from tests import unittest
+from tests.test_utils import make_awaitable
class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
@@ -765,6 +768,8 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_user_id = "@test:other"
local_user_id = "@test:test"
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
return_value=defer.succeed({"some_room_id"})
)
@@ -831,3 +836,94 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
}
},
)
+
+ @parameterized.expand(
+ [
+ # The remote homeserver's response indicates that this user has 0/1/2 devices.
+ ([],),
+ (["device_1"],),
+ (["device_1", "device_2"],),
+ ]
+ )
+ def test_query_all_devices_caches_result(self, device_ids: Iterable[str]):
+ """Test that requests for all of a remote user's devices are cached.
+
+ We do this by asserting that only one call over federation was made, and that
+ the two queries to the local homeserver produce the same response.
+ """
+ local_user_id = "@test:test"
+ remote_user_id = "@test:other"
+ request_body = {"device_keys": {remote_user_id: []}}
+
+ response_devices = [
+ {
+ "device_id": device_id,
+ "keys": {
+ "algorithms": ["dummy"],
+ "device_id": device_id,
+ "keys": {f"dummy:{device_id}": "dummy"},
+ "signatures": {device_id: {f"dummy:{device_id}": "dummy"}},
+ "unsigned": {},
+ "user_id": "@test:other",
+ },
+ }
+ for device_id in device_ids
+ ]
+
+ response_body = {
+ "devices": response_devices,
+ "user_id": remote_user_id,
+ "stream_id": 12345, # an integer, according to the spec
+ }
+
+ e2e_handler = self.hs.get_e2e_keys_handler()
+
+ # Pretend we're sharing a room with the user we're querying. If not,
+ # `_query_devices_for_destination` will return early.
+ mock_get_rooms = mock.patch.object(
+ self.store,
+ "get_rooms_for_user",
+ new_callable=mock.MagicMock,
+ return_value=make_awaitable(["some_room_id"]),
+ )
+ mock_request = mock.patch.object(
+ self.hs.get_federation_client(),
+ "query_user_devices",
+ new_callable=mock.MagicMock,
+ return_value=make_awaitable(response_body),
+ )
+
+ with mock_get_rooms, mock_request as mocked_federation_request:
+ # Make the first query and sanity check it succeeds.
+ response_1 = self.get_success(
+ e2e_handler.query_devices(
+ request_body,
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+ self.assertEqual(response_1["failures"], {})
+
+ # We should have made a federation request to do so.
+ mocked_federation_request.assert_called_once()
+
+ # Reset the mock so we can prove we don't make a second federation request.
+ mocked_federation_request.reset_mock()
+
+ # Repeat the query.
+ response_2 = self.get_success(
+ e2e_handler.query_devices(
+ request_body,
+ timeout=10,
+ from_user_id=local_user_id,
+ from_device_id="some_device_id",
+ )
+ )
+ self.assertEqual(response_2["failures"], {})
+
+ # We should not have made a second federation request.
+ mocked_federation_request.assert_not_called()
+
+ # The two requests to the local homeserver should be identical.
+ self.assertEqual(response_1, response_2)
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 08e9730d4d..2add72b28a 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -22,7 +22,7 @@ from twisted.internet import defer
import synapse
from synapse.handlers.auth import load_legacy_password_auth_providers
from synapse.module_api import ModuleApi
-from synapse.rest.client import devices, login
+from synapse.rest.client import devices, login, logout
from synapse.types import JsonDict
from tests import unittest
@@ -155,6 +155,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
synapse.rest.admin.register_servlets,
login.register_servlets,
devices.register_servlets,
+ logout.register_servlets,
]
def setUp(self):
@@ -719,6 +720,31 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
+ def test_on_logged_out(self):
+ """Tests that the on_logged_out callback is called when the user logs out."""
+ self.register_user("rin", "password")
+ tok = self.login("rin", "password")
+
+ self.called = False
+
+ async def on_logged_out(user_id, device_id, access_token):
+ self.called = True
+
+ on_logged_out = Mock(side_effect=on_logged_out)
+ self.hs.get_password_auth_provider().on_logged_out_callbacks.append(
+ on_logged_out
+ )
+
+ channel = self.make_request(
+ "POST",
+ "/_matrix/client/v3/logout",
+ {},
+ access_token=tok,
+ )
+ self.assertEqual(channel.code, 200)
+ on_logged_out.assert_called_once()
+ self.assertTrue(self.called)
+
def _get_login_flows(self) -> JsonDict:
channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index e5a6a6c747..51b22d2998 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -28,6 +28,7 @@ from synapse.api.constants import (
from synapse.api.errors import AuthError, NotFoundError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
+from synapse.federation.transport.client import TransportLayerClient
from synapse.handlers.room_summary import _child_events_comparison_key, _RoomEntry
from synapse.rest import admin
from synapse.rest.client import login, room
@@ -134,10 +135,18 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
self._add_child(self.space, self.room, self.token)
def _add_child(
- self, space_id: str, room_id: str, token: str, order: Optional[str] = None
+ self,
+ space_id: str,
+ room_id: str,
+ token: str,
+ order: Optional[str] = None,
+ via: Optional[List[str]] = None,
) -> None:
"""Add a child room to a space."""
- content: JsonDict = {"via": [self.hs.hostname]}
+ if via is None:
+ via = [self.hs.hostname]
+
+ content: JsonDict = {"via": via}
if order is not None:
content["order"] = order
self.helper.send_state(
@@ -253,6 +262,38 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
+ def test_large_space(self):
+ """Test a space with a large number of rooms."""
+ rooms = [self.room]
+ # Make at least 51 rooms that are part of the space.
+ for _ in range(55):
+ room = self.helper.create_room_as(self.user, tok=self.token)
+ self._add_child(self.space, room, self.token)
+ rooms.append(room)
+
+ result = self.get_success(self.handler.get_space_summary(self.user, self.space))
+ # The spaces result should have the space and the first 50 rooms in it,
+ # along with the links from space -> room for those 50 rooms.
+ expected = [(self.space, rooms[:50])] + [(room, []) for room in rooms[:49]]
+ self._assert_rooms(result, expected)
+
+ # The result should have the space and the rooms in it, along with the links
+ # from space -> room.
+ expected = [(self.space, rooms)] + [(room, []) for room in rooms]
+
+ # Make two requests to fully paginate the results.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ result2 = self.get_success(
+ self.handler.get_room_hierarchy(
+ create_requester(self.user), self.space, from_token=result["next_batch"]
+ )
+ )
+ # Combine the results.
+ result["rooms"] += result2["rooms"]
+ self._assert_hierarchy(result, expected)
+
def test_visibility(self):
"""A user not in a space cannot inspect it."""
user2 = self.register_user("user2", "pass")
@@ -1004,6 +1045,85 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
)
self._assert_hierarchy(result, expected)
+ def test_fed_caching(self):
+ """
+ Federation `/hierarchy` responses should be cached.
+ """
+ fed_hostname = self.hs.hostname + "2"
+ fed_subspace = "#space:" + fed_hostname
+ fed_room = "#room:" + fed_hostname
+
+ # Add a room to the space which is on another server.
+ self._add_child(self.space, fed_subspace, self.token, via=[fed_hostname])
+
+ federation_requests = 0
+
+ async def get_room_hierarchy(
+ _self: TransportLayerClient,
+ destination: str,
+ room_id: str,
+ suggested_only: bool,
+ ) -> JsonDict:
+ nonlocal federation_requests
+ federation_requests += 1
+
+ return {
+ "room": {
+ "room_id": fed_subspace,
+ "world_readable": True,
+ "room_type": RoomTypes.SPACE,
+ "children_state": [
+ {
+ "type": EventTypes.SpaceChild,
+ "room_id": fed_subspace,
+ "state_key": fed_room,
+ "content": {"via": [fed_hostname]},
+ },
+ ],
+ },
+ "children": [
+ {
+ "room_id": fed_room,
+ "world_readable": True,
+ },
+ ],
+ "inaccessible_children": [],
+ }
+
+ expected = [
+ (self.space, [self.room, fed_subspace]),
+ (self.room, ()),
+ (fed_subspace, [fed_room]),
+ (fed_room, ()),
+ ]
+
+ with mock.patch(
+ "synapse.federation.transport.client.TransportLayerClient.get_room_hierarchy",
+ new=get_room_hierarchy,
+ ):
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ self.assertEqual(federation_requests, 1)
+ self._assert_hierarchy(result, expected)
+
+ # The previous federation response should be reused.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ self.assertEqual(federation_requests, 1)
+ self._assert_hierarchy(result, expected)
+
+ # Expire the response cache
+ self.reactor.advance(5 * 60 + 1)
+
+ # A new federation request should be made.
+ result = self.get_success(
+ self.handler.get_room_hierarchy(create_requester(self.user), self.space)
+ )
+ self.assertEqual(federation_requests, 2)
+ self._assert_hierarchy(result, expected)
+
class RoomSummaryTestCase(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py
index 638186f173..07a760e91a 100644
--- a/tests/handlers/test_sync.py
+++ b/tests/handlers/test_sync.py
@@ -11,15 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Optional
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock, patch
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import Filtering
from synapse.api.room_versions import RoomVersions
-from synapse.handlers.sync import SyncConfig
+from synapse.handlers.sync import SyncConfig, SyncResult
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
@@ -27,6 +26,7 @@ from synapse.types import UserID, create_requester
import tests.unittest
import tests.utils
+from tests.test_utils import make_awaitable
class SyncTestCase(tests.unittest.HomeserverTestCase):
@@ -186,6 +186,97 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.assertNotIn(invite_room, [r.room_id for r in result.invited])
self.assertNotIn(knock_room, [r.room_id for r in result.knocked])
+ def test_ban_wins_race_with_join(self):
+ """Rooms shouldn't appear under "joined" if a join loses a race to a ban.
+
+ A complicated edge case. Imagine the following scenario:
+
+ * you attempt to join a room
+ * racing with that is a ban which comes in over federation, which ends up with
+ an earlier stream_ordering than the join.
+ * you get a sync response with a sync token which is _after_ the ban, but before
+ the join
+ * now your join lands; it is a valid event because its `prev_event`s predate the
+ ban, but will not make it into current_state_events (because bans win over
+ joins in state res, essentially).
+ * When we do a sync from the incremental sync, the only event in the timeline
+ is your join ... and yet you aren't joined.
+
+ The ban coming in over federation isn't crucial for this behaviour; the key
+ requirements are:
+ 1. the homeserver generates a join event with prev_events that precede the ban
+ (so that it passes the "are you banned" test)
+ 2. the join event has a stream_ordering after that of the ban.
+
+ We use monkeypatching to artificially trigger condition (1).
+ """
+ # A local user Alice creates a room.
+ owner = self.register_user("alice", "password")
+ owner_tok = self.login(owner, "password")
+ room_id = self.helper.create_room_as(owner, is_public=True, tok=owner_tok)
+
+ # Do a sync as Alice to get the latest event in the room.
+ alice_sync_result: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ create_requester(owner), generate_sync_config(owner)
+ )
+ )
+ self.assertEqual(len(alice_sync_result.joined), 1)
+ self.assertEqual(alice_sync_result.joined[0].room_id, room_id)
+ last_room_creation_event_id = (
+ alice_sync_result.joined[0].timeline.events[-1].event_id
+ )
+
+ # Eve, a ne'er-do-well, registers.
+ eve = self.register_user("eve", "password")
+ eve_token = self.login(eve, "password")
+
+ # Alice preemptively bans Eve.
+ self.helper.ban(room_id, owner, eve, tok=owner_tok)
+
+ # Eve syncs.
+ eve_requester = create_requester(eve)
+ eve_sync_config = generate_sync_config(eve)
+ eve_sync_after_ban: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(eve_requester, eve_sync_config)
+ )
+
+ # Sanity check this sync result. We shouldn't be joined to the room.
+ self.assertEqual(eve_sync_after_ban.joined, [])
+
+ # Eve tries to join the room. We monkey patch the internal logic which selects
+ # the prev_events used when creating the join event, such that the ban does not
+ # precede the join.
+ mocked_get_prev_events = patch.object(
+ self.hs.get_datastore(),
+ "get_prev_events_for_room",
+ new_callable=MagicMock,
+ return_value=make_awaitable([last_room_creation_event_id]),
+ )
+ with mocked_get_prev_events:
+ self.helper.join(room_id, eve, tok=eve_token)
+
+ # Eve makes a second, incremental sync.
+ eve_incremental_sync_after_join: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ eve_requester,
+ eve_sync_config,
+ since_token=eve_sync_after_ban.next_batch,
+ )
+ )
+ # Eve should not see herself as joined to the room.
+ self.assertEqual(eve_incremental_sync_after_join.joined, [])
+
+ # If we did a third initial sync, we should _still_ see eve is not joined to the room.
+ eve_initial_sync_after_join: SyncResult = self.get_success(
+ self.sync_handler.wait_for_sync_for_user(
+ eve_requester,
+ eve_sync_config,
+ since_token=None,
+ )
+ )
+ self.assertEqual(eve_initial_sync_after_join.joined, [])
+
_request_key = 0
|