summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/10527.misc2
-rw-r--r--changelog.d/10530.misc2
-rw-r--r--changelog.d/10549.feature1
-rw-r--r--synapse/handlers/space_summary.py201
-rw-r--r--synapse/rest/client/v1/room.py41
-rw-r--r--tests/handlers/test_space_summary.py386
6 files changed, 521 insertions, 112 deletions
diff --git a/changelog.d/10527.misc b/changelog.d/10527.misc
index 3cf22f9daf..ffc4e4289c 100644
--- a/changelog.d/10527.misc
+++ b/changelog.d/10527.misc
@@ -1 +1 @@
-Prepare for the new spaces summary endpoint (updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)).
+Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946).
diff --git a/changelog.d/10530.misc b/changelog.d/10530.misc
index 3cf22f9daf..ffc4e4289c 100644
--- a/changelog.d/10530.misc
+++ b/changelog.d/10530.misc
@@ -1 +1 @@
-Prepare for the new spaces summary endpoint (updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946)).
+Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946).
diff --git a/changelog.d/10549.feature b/changelog.d/10549.feature
new file mode 100644
index 0000000000..ffc4e4289c
--- /dev/null
+++ b/changelog.d/10549.feature
@@ -0,0 +1 @@
+Add pagination to the spaces summary based on updates to [MSC2946](https://github.com/matrix-org/matrix-doc/pull/2946).
diff --git a/synapse/handlers/space_summary.py b/synapse/handlers/space_summary.py
index d04afe6c31..fd76c34695 100644
--- a/synapse/handlers/space_summary.py
+++ b/synapse/handlers/space_summary.py
@@ -18,7 +18,7 @@ import re
 from collections import deque
 from typing import (
     TYPE_CHECKING,
-    Collection,
+    Deque,
     Dict,
     Iterable,
     List,
@@ -38,9 +38,12 @@ from synapse.api.constants import (
     Membership,
     RoomTypes,
 )
+from synapse.api.errors import Codes, SynapseError
 from synapse.events import EventBase
 from synapse.events.utils import format_event_for_client_v2
 from synapse.types import JsonDict
+from synapse.util.caches.response_cache import ResponseCache
+from synapse.util.stringutils import random_string
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -57,6 +60,29 @@ MAX_ROOMS_PER_SPACE = 50
 MAX_SERVERS_PER_SPACE = 3
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _PaginationKey:
+    """The key used to find unique pagination session."""
+
+    # The first three entries match the request parameters (and cannot change
+    # during a pagination session).
+    room_id: str
+    suggested_only: bool
+    max_depth: Optional[int]
+    # The randomly generated token.
+    token: str
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _PaginationSession:
+    """The information that is stored for pagination."""
+
+    # The queue of rooms which are still to process.
+    room_queue: Deque["_RoomQueueEntry"]
+    # A set of rooms which have been processed.
+    processed_rooms: Set[str]
+
+
 class SpaceSummaryHandler:
     def __init__(self, hs: "HomeServer"):
         self._clock = hs.get_clock()
@@ -67,6 +93,21 @@ class SpaceSummaryHandler:
         self._server_name = hs.hostname
         self._federation_client = hs.get_federation_client()
 
+        # A map of query information to the current pagination state.
+        #
+        # TODO Allow for multiple workers to share this data.
+        # TODO Expire pagination tokens.
+        self._pagination_sessions: Dict[_PaginationKey, _PaginationSession] = {}
+
+        # If a user tries to fetch the same page multiple times in quick succession,
+        # only process the first attempt and return its result to subsequent requests.
+        self._pagination_response_cache: ResponseCache[
+            Tuple[str, bool, Optional[int], Optional[int], Optional[str]]
+        ] = ResponseCache(
+            hs.get_clock(),
+            "get_room_hierarchy",
+        )
+
     async def get_space_summary(
         self,
         requester: str,
@@ -130,7 +171,7 @@ class SpaceSummaryHandler:
                     requester, None, room_id, suggested_only, max_children
                 )
 
-                events: Collection[JsonDict] = []
+                events: Sequence[JsonDict] = []
                 if room_entry:
                     rooms_result.append(room_entry.room)
                     events = room_entry.children
@@ -207,6 +248,154 @@ class SpaceSummaryHandler:
 
         return {"rooms": rooms_result, "events": events_result}
 
+    async def get_room_hierarchy(
+        self,
+        requester: str,
+        requested_room_id: str,
+        suggested_only: bool = False,
+        max_depth: Optional[int] = None,
+        limit: Optional[int] = None,
+        from_token: Optional[str] = None,
+    ) -> JsonDict:
+        """
+        Implementation of the room hierarchy C-S API.
+
+        Args:
+            requester: The user ID of the user making this request.
+            requested_room_id: The room ID to start the hierarchy at (the "root" room).
+            suggested_only: Whether we should only return children with the "suggested"
+                flag set.
+            max_depth: The maximum depth in the tree to explore, must be a
+                non-negative integer.
+
+                0 would correspond to just the root room, 1 would include just
+                the root room's children, etc.
+            limit: An optional limit on the number of rooms to return per
+                page. Must be a positive integer.
+            from_token: An optional pagination token.
+
+        Returns:
+            The JSON hierarchy dictionary.
+        """
+        # If a user tries to fetch the same page multiple times in quick succession,
+        # only process the first attempt and return its result to subsequent requests.
+        #
+        # This is due to the pagination process mutating internal state, attempting
+        # to process multiple requests for the same page will result in errors.
+        return await self._pagination_response_cache.wrap(
+            (requested_room_id, suggested_only, max_depth, limit, from_token),
+            self._get_room_hierarchy,
+            requester,
+            requested_room_id,
+            suggested_only,
+            max_depth,
+            limit,
+            from_token,
+        )
+
+    async def _get_room_hierarchy(
+        self,
+        requester: str,
+        requested_room_id: str,
+        suggested_only: bool = False,
+        max_depth: Optional[int] = None,
+        limit: Optional[int] = None,
+        from_token: Optional[str] = None,
+    ) -> JsonDict:
+        """See docstring for SpaceSummaryHandler.get_room_hierarchy."""
+
+        # first of all, check that the user is in the room in question (or it's
+        # world-readable)
+        await self._auth.check_user_in_room_or_world_readable(
+            requested_room_id, requester
+        )
+
+        # If this is continuing a previous session, pull the persisted data.
+        if from_token:
+            pagination_key = _PaginationKey(
+                requested_room_id, suggested_only, max_depth, from_token
+            )
+            if pagination_key not in self._pagination_sessions:
+                raise SynapseError(400, "Unknown pagination token", Codes.INVALID_PARAM)
+
+            # Load the previous state.
+            pagination_session = self._pagination_sessions[pagination_key]
+            room_queue = pagination_session.room_queue
+            processed_rooms = pagination_session.processed_rooms
+        else:
+            # the queue of rooms to process
+            room_queue = deque((_RoomQueueEntry(requested_room_id, ()),))
+
+            # Rooms we have already processed.
+            processed_rooms = set()
+
+        rooms_result: List[JsonDict] = []
+
+        # Cap the limit to a server-side maximum.
+        if limit is None:
+            limit = MAX_ROOMS
+        else:
+            limit = min(limit, MAX_ROOMS)
+
+        # Iterate through the queue until we reach the limit or run out of
+        # rooms to include.
+        while room_queue and len(rooms_result) < limit:
+            queue_entry = room_queue.popleft()
+            room_id = queue_entry.room_id
+            current_depth = queue_entry.depth
+            if room_id in processed_rooms:
+                # already done this room
+                continue
+
+            logger.debug("Processing room %s", room_id)
+
+            is_in_room = await self._store.is_host_joined(room_id, self._server_name)
+            if is_in_room:
+                room_entry = await self._summarize_local_room(
+                    requester,
+                    None,
+                    room_id,
+                    suggested_only,
+                    # TODO Handle max children.
+                    max_children=None,
+                )
+
+                if room_entry:
+                    rooms_result.append(room_entry.as_json())
+
+                    # Add the child to the queue. We have already validated
+                    # that the vias are a list of server names.
+                    #
+                    # If the current depth is the maximum depth, do not queue
+                    # more entries.
+                    if max_depth is None or current_depth < max_depth:
+                        room_queue.extendleft(
+                            _RoomQueueEntry(
+                                ev["state_key"], ev["content"]["via"], current_depth + 1
+                            )
+                            for ev in reversed(room_entry.children)
+                        )
+
+                processed_rooms.add(room_id)
+            else:
+                # TODO Federation.
+                pass
+
+        result: JsonDict = {"rooms": rooms_result}
+
+        # If there's additional data, generate a pagination token (and persist state).
+        if room_queue:
+            next_token = random_string(24)
+            result["next_token"] = next_token
+            pagination_key = _PaginationKey(
+                requested_room_id, suggested_only, max_depth, next_token
+            )
+            self._pagination_sessions[pagination_key] = _PaginationSession(
+                room_queue, processed_rooms
+            )
+
+        return result
+
     async def federation_space_summary(
         self,
         origin: str,
@@ -652,6 +841,7 @@ class SpaceSummaryHandler:
 class _RoomQueueEntry:
     room_id: str
     via: Sequence[str]
+    depth: int = 0
 
 
 @attr.s(frozen=True, slots=True, auto_attribs=True)
@@ -662,7 +852,12 @@ class _RoomEntry:
     # An iterable of the sorted, stripped children events for children of this room.
     #
     # This may not include all children.
-    children: Collection[JsonDict] = ()
+    children: Sequence[JsonDict] = ()
+
+    def as_json(self) -> JsonDict:
+        result = dict(self.room)
+        result["children_state"] = self.children
+        return result
 
 
 def _has_valid_via(e: EventBase) -> bool:
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index f887970b76..b28b72bfbd 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -1445,6 +1445,46 @@ class RoomSpaceSummaryRestServlet(RestServlet):
         )
 
 
+class RoomHierarchyRestServlet(RestServlet):
+    PATTERNS = (
+        re.compile(
+            "^/_matrix/client/unstable/org.matrix.msc2946"
+            "/rooms/(?P<room_id>[^/]*)/hierarchy$"
+        ),
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self._auth = hs.get_auth()
+        self._space_summary_handler = hs.get_space_summary_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self._auth.get_user_by_req(request, allow_guest=True)
+
+        max_depth = parse_integer(request, "max_depth")
+        if max_depth is not None and max_depth < 0:
+            raise SynapseError(
+                400, "'max_depth' must be a non-negative integer", Codes.BAD_JSON
+            )
+
+        limit = parse_integer(request, "limit")
+        if limit is not None and limit <= 0:
+            raise SynapseError(
+                400, "'limit' must be a positive integer", Codes.BAD_JSON
+            )
+
+        return 200, await self._space_summary_handler.get_room_hierarchy(
+            requester.user.to_string(),
+            room_id,
+            suggested_only=parse_boolean(request, "suggested_only", default=False),
+            max_depth=max_depth,
+            limit=limit,
+            from_token=parse_string(request, "from"),
+        )
+
+
 def register_servlets(hs: "HomeServer", http_server, is_worker=False):
     msc2716_enabled = hs.config.experimental.msc2716_enabled
 
@@ -1463,6 +1503,7 @@ def register_servlets(hs: "HomeServer", http_server, is_worker=False):
     RoomTypingRestServlet(hs).register(http_server)
     RoomEventContextServlet(hs).register(http_server)
     RoomSpaceSummaryRestServlet(hs).register(http_server)
+    RoomHierarchyRestServlet(hs).register(http_server)
     RoomEventServlet(hs).register(http_server)
     JoinedRoomsRestServlet(hs).register(http_server)
     RoomAliasListServlet(hs).register(http_server)
diff --git a/tests/handlers/test_space_summary.py b/tests/handlers/test_space_summary.py
index f470c81ea2..255dd17f86 100644
--- a/tests/handlers/test_space_summary.py
+++ b/tests/handlers/test_space_summary.py
@@ -23,7 +23,7 @@ from synapse.api.constants import (
     RestrictedJoinRuleTypes,
     RoomTypes,
 )
-from synapse.api.errors import AuthError
+from synapse.api.errors import AuthError, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import make_event_from_dict
 from synapse.handlers.space_summary import _child_events_comparison_key, _RoomEntry
@@ -123,32 +123,83 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         self.room = self.helper.create_room_as(self.user, tok=self.token)
         self._add_child(self.space, self.room, self.token)
 
-    def _add_child(self, space_id: str, room_id: str, token: str) -> None:
+    def _add_child(
+        self, space_id: str, room_id: str, token: str, order: Optional[str] = None
+    ) -> None:
         """Add a child room to a space."""
+        content = {"via": [self.hs.hostname]}
+        if order is not None:
+            content["order"] = order
         self.helper.send_state(
             space_id,
             event_type=EventTypes.SpaceChild,
-            body={"via": [self.hs.hostname]},
+            body=content,
             tok=token,
             state_key=room_id,
         )
 
-    def _assert_rooms(self, result: JsonDict, rooms: Iterable[str]) -> None:
-        """Assert that the expected room IDs are in the response."""
-        self.assertCountEqual([room.get("room_id") for room in result["rooms"]], rooms)
-
-    def _assert_events(
-        self, result: JsonDict, events: Iterable[Tuple[str, str]]
+    def _assert_rooms(
+        self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]]
     ) -> None:
-        """Assert that the expected parent / child room IDs are in the response."""
+        """
+        Assert that the expected room IDs and events are in the response.
+
+        Args:
+            result: The result from the API call.
+            rooms_and_children: An iterable of tuples where each tuple is:
+                The expected room ID.
+                The expected IDs of any children rooms.
+        """
+        room_ids = []
+        children_ids = []
+        for room_id, children in rooms_and_children:
+            room_ids.append(room_id)
+            if children:
+                children_ids.extend([(room_id, child_id) for child_id in children])
+        self.assertCountEqual(
+            [room.get("room_id") for room in result["rooms"]], room_ids
+        )
         self.assertCountEqual(
             [
                 (event.get("room_id"), event.get("state_key"))
                 for event in result["events"]
             ],
-            events,
+            children_ids,
         )
 
+    def _assert_hierarchy(
+        self, result: JsonDict, rooms_and_children: Iterable[Tuple[str, Iterable[str]]]
+    ) -> None:
+        """
+        Assert that the expected room IDs are in the response.
+
+        Args:
+            result: The result from the API call.
+            rooms_and_children: An iterable of tuples where each tuple is:
+                The expected room ID.
+                The expected IDs of any children rooms.
+        """
+        result_room_ids = []
+        result_children_ids = []
+        for result_room in result["rooms"]:
+            result_room_ids.append(result_room["room_id"])
+            result_children_ids.append(
+                [
+                    (cs["room_id"], cs["state_key"])
+                    for cs in result_room.get("children_state")
+                ]
+            )
+
+        room_ids = []
+        children_ids = []
+        for room_id, children in rooms_and_children:
+            room_ids.append(room_id)
+            children_ids.append([(room_id, child_id) for child_id in children])
+
+        # Note that order matters.
+        self.assertEqual(result_room_ids, room_ids)
+        self.assertEqual(result_children_ids, children_ids)
+
     def _poke_fed_invite(self, room_id: str, from_user: str) -> None:
         """
         Creates a invite (as if received over federation) for the room from the
@@ -184,8 +235,13 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         result = self.get_success(self.handler.get_space_summary(self.user, self.space))
         # The result should have the space and the room in it, along with a link
         # from space -> room.
-        self._assert_rooms(result, [self.space, self.room])
-        self._assert_events(result, [(self.space, self.room)])
+        expected = [(self.space, [self.room]), (self.room, ())]
+        self._assert_rooms(result, expected)
+
+        result = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space)
+        )
+        self._assert_hierarchy(result, expected)
 
     def test_visibility(self):
         """A user not in a space cannot inspect it."""
@@ -194,6 +250,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
 
         # The user cannot see the space.
         self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
+        self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError)
 
         # If the space is made world-readable it should return a result.
         self.helper.send_state(
@@ -203,8 +260,11 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             tok=self.token,
         )
         result = self.get_success(self.handler.get_space_summary(user2, self.space))
-        self._assert_rooms(result, [self.space, self.room])
-        self._assert_events(result, [(self.space, self.room)])
+        expected = [(self.space, [self.room]), (self.room, ())]
+        self._assert_rooms(result, expected)
+
+        result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+        self._assert_hierarchy(result, expected)
 
         # Make it not world-readable again and confirm it results in an error.
         self.helper.send_state(
@@ -214,12 +274,15 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
             tok=self.token,
         )
         self.get_failure(self.handler.get_space_summary(user2, self.space), AuthError)
+        self.get_failure(self.handler.get_room_hierarchy(user2, self.space), AuthError)
 
         # Join the space and results should be returned.
         self.helper.join(self.space, user2, tok=token2)
         result = self.get_success(self.handler.get_space_summary(user2, self.space))
-        self._assert_rooms(result, [self.space, self.room])
-        self._assert_events(result, [(self.space, self.room)])
+        self._assert_rooms(result, expected)
+
+        result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+        self._assert_hierarchy(result, expected)
 
     def _create_room_with_join_rule(
         self, join_rule: str, room_version: Optional[str] = None, **extra_content
@@ -290,34 +353,33 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         # Join the space.
         self.helper.join(self.space, user2, tok=token2)
         result = self.get_success(self.handler.get_space_summary(user2, self.space))
-
-        self._assert_rooms(
-            result,
-            [
+        expected = [
+            (
                 self.space,
-                self.room,
-                public_room,
-                knock_room,
-                invited_room,
-                restricted_accessible_room,
-                world_readable_room,
-                joined_room,
-            ],
-        )
-        self._assert_events(
-            result,
-            [
-                (self.space, self.room),
-                (self.space, public_room),
-                (self.space, knock_room),
-                (self.space, not_invited_room),
-                (self.space, invited_room),
-                (self.space, restricted_room),
-                (self.space, restricted_accessible_room),
-                (self.space, world_readable_room),
-                (self.space, joined_room),
-            ],
-        )
+                [
+                    self.room,
+                    public_room,
+                    knock_room,
+                    not_invited_room,
+                    invited_room,
+                    restricted_room,
+                    restricted_accessible_room,
+                    world_readable_room,
+                    joined_room,
+                ],
+            ),
+            (self.room, ()),
+            (public_room, ()),
+            (knock_room, ()),
+            (invited_room, ()),
+            (restricted_accessible_room, ()),
+            (world_readable_room, ()),
+            (joined_room, ()),
+        ]
+        self._assert_rooms(result, expected)
+
+        result = self.get_success(self.handler.get_room_hierarchy(user2, self.space))
+        self._assert_hierarchy(result, expected)
 
     def test_complex_space(self):
         """
@@ -349,19 +411,145 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
         result = self.get_success(self.handler.get_space_summary(self.user, self.space))
 
         # The result should include each room a single time and each link.
-        self._assert_rooms(result, [self.space, self.room, subspace, subroom])
-        self._assert_events(
-            result,
-            [
-                (self.space, self.room),
-                (self.space, room2),
-                (self.space, subspace),
-                (subspace, subroom),
-                (subspace, self.room),
-                (subspace, room2),
-            ],
+        expected = [
+            (self.space, [self.room, room2, subspace]),
+            (self.room, ()),
+            (subspace, [subroom, self.room, room2]),
+            (subroom, ()),
+        ]
+        self._assert_rooms(result, expected)
+
+        result = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space)
+        )
+        self._assert_hierarchy(result, expected)
+
+    def test_pagination(self):
+        """Test simple pagination works."""
+        room_ids = []
+        for i in range(1, 10):
+            room = self.helper.create_room_as(self.user, tok=self.token)
+            self._add_child(self.space, room, self.token, order=str(i))
+            room_ids.append(room)
+        # The room created initially doesn't have an order, so comes last.
+        room_ids.append(self.room)
+
+        result = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space, limit=7)
+        )
+        # The result should have the space and all of the links, plus some of the
+        # rooms and a pagination token.
+        expected = [(self.space, room_ids)] + [
+            (room_id, ()) for room_id in room_ids[:6]
+        ]
+        self._assert_hierarchy(result, expected)
+        self.assertIn("next_token", result)
+
+        # Check the next page.
+        result = self.get_success(
+            self.handler.get_room_hierarchy(
+                self.user, self.space, limit=5, from_token=result["next_token"]
+            )
+        )
+        # The result should have the space and the room in it, along with a link
+        # from space -> room.
+        expected = [(room_id, ()) for room_id in room_ids[6:]]
+        self._assert_hierarchy(result, expected)
+        self.assertNotIn("next_token", result)
+
+    def test_invalid_pagination_token(self):
+        """"""
+        room_ids = []
+        for i in range(1, 10):
+            room = self.helper.create_room_as(self.user, tok=self.token)
+            self._add_child(self.space, room, self.token, order=str(i))
+            room_ids.append(room)
+        # The room created initially doesn't have an order, so comes last.
+        room_ids.append(self.room)
+
+        result = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space, limit=7)
+        )
+        self.assertIn("next_token", result)
+
+        # Changing the room ID, suggested-only, or max-depth causes an error.
+        self.get_failure(
+            self.handler.get_room_hierarchy(
+                self.user, self.room, from_token=result["next_token"]
+            ),
+            SynapseError,
+        )
+        self.get_failure(
+            self.handler.get_room_hierarchy(
+                self.user,
+                self.space,
+                suggested_only=True,
+                from_token=result["next_token"],
+            ),
+            SynapseError,
+        )
+        self.get_failure(
+            self.handler.get_room_hierarchy(
+                self.user, self.space, max_depth=0, from_token=result["next_token"]
+            ),
+            SynapseError,
         )
 
+        # An invalid token is ignored.
+        self.get_failure(
+            self.handler.get_room_hierarchy(self.user, self.space, from_token="foo"),
+            SynapseError,
+        )
+
+    def test_max_depth(self):
+        """Create a deep tree to test the max depth against."""
+        spaces = [self.space]
+        rooms = [self.room]
+        for _ in range(5):
+            spaces.append(
+                self.helper.create_room_as(
+                    self.user,
+                    tok=self.token,
+                    extra_content={
+                        "creation_content": {
+                            EventContentFields.ROOM_TYPE: RoomTypes.SPACE
+                        }
+                    },
+                )
+            )
+            self._add_child(spaces[-2], spaces[-1], self.token)
+            rooms.append(self.helper.create_room_as(self.user, tok=self.token))
+            self._add_child(spaces[-1], rooms[-1], self.token)
+
+        # Test just the space itself.
+        result = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space, max_depth=0)
+        )
+        expected = [(spaces[0], [rooms[0], spaces[1]])]
+        self._assert_hierarchy(result, expected)
+
+        # A single additional layer.
+        result = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space, max_depth=1)
+        )
+        expected += [
+            (rooms[0], ()),
+            (spaces[1], [rooms[1], spaces[2]]),
+        ]
+        self._assert_hierarchy(result, expected)
+
+        # A few layers.
+        result = self.get_success(
+            self.handler.get_room_hierarchy(self.user, self.space, max_depth=3)
+        )
+        expected += [
+            (rooms[1], ()),
+            (spaces[2], [rooms[2], spaces[3]]),
+            (rooms[2], ()),
+            (spaces[3], [rooms[3], spaces[4]]),
+        ]
+        self._assert_hierarchy(result, expected)
+
     def test_fed_complex(self):
         """
         Return data over federation and ensure that it is handled properly.
@@ -417,15 +605,13 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
                 self.handler.get_space_summary(self.user, self.space)
             )
 
-        self._assert_rooms(result, [self.space, self.room, subspace, subroom])
-        self._assert_events(
-            result,
-            [
-                (self.space, self.room),
-                (self.space, subspace),
-                (subspace, subroom),
-            ],
-        )
+        expected = [
+            (self.space, [self.room, subspace]),
+            (self.room, ()),
+            (subspace, [subroom]),
+            (subroom, ()),
+        ]
+        self._assert_rooms(result, expected)
 
     def test_fed_filtering(self):
         """
@@ -554,35 +740,30 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
                 self.handler.get_space_summary(self.user, self.space)
             )
 
-        self._assert_rooms(
-            result,
-            [
-                self.space,
-                self.room,
+        expected = [
+            (self.space, [self.room, subspace]),
+            (self.room, ()),
+            (
                 subspace,
-                public_room,
-                knock_room,
-                invited_room,
-                restricted_accessible_room,
-                world_readable_room,
-                joined_room,
-            ],
-        )
-        self._assert_events(
-            result,
-            [
-                (self.space, self.room),
-                (self.space, subspace),
-                (subspace, public_room),
-                (subspace, knock_room),
-                (subspace, not_invited_room),
-                (subspace, invited_room),
-                (subspace, restricted_room),
-                (subspace, restricted_accessible_room),
-                (subspace, world_readable_room),
-                (subspace, joined_room),
-            ],
-        )
+                [
+                    public_room,
+                    knock_room,
+                    not_invited_room,
+                    invited_room,
+                    restricted_room,
+                    restricted_accessible_room,
+                    world_readable_room,
+                    joined_room,
+                ],
+            ),
+            (public_room, ()),
+            (knock_room, ()),
+            (invited_room, ()),
+            (restricted_accessible_room, ()),
+            (world_readable_room, ()),
+            (joined_room, ()),
+        ]
+        self._assert_rooms(result, expected)
 
     def test_fed_invited(self):
         """
@@ -623,18 +804,9 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
                 self.handler.get_space_summary(self.user, self.space)
             )
 
-        self._assert_rooms(
-            result,
-            [
-                self.space,
-                self.room,
-                fed_room,
-            ],
-        )
-        self._assert_events(
-            result,
-            [
-                (self.space, self.room),
-                (self.space, fed_room),
-            ],
-        )
+        expected = [
+            (self.space, [self.room, fed_room]),
+            (self.room, ()),
+            (fed_room, ()),
+        ]
+        self._assert_rooms(result, expected)