summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--tests/handlers/test_space_hierarchy.py64
-rw-r--r--tests/rest/admin/test_space.py118
2 files changed, 177 insertions, 5 deletions
diff --git a/tests/handlers/test_space_hierarchy.py b/tests/handlers/test_space_hierarchy.py
index 548173d1db..63bc93d558 100644
--- a/tests/handlers/test_space_hierarchy.py
+++ b/tests/handlers/test_space_hierarchy.py
@@ -12,11 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Optional
+from typing import Dict, Iterable, Mapping, NoReturn, Optional, Sequence, Tuple
+from unittest import mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventContentFields, EventTypes, RoomTypes
+from synapse.handlers.space_hierarchy import SpaceHierarchyHandler
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
@@ -120,6 +122,66 @@ class SpaceDescendantsTestCase(unittest.HomeserverTestCase):
         self.assertEqual(descendants, [(space_id, []), (room_id, [self.hs.hostname])])
         self.assertEqual(inaccessible_room_ids, [room_id])
 
+    def test_remote_space_with_federation_enabled(self):
+        """Tests iteration over a remote space with federation enabled."""
+        space_id = "!space:remote"
+        room_id = "!room:remote"
+
+        async def _get_space_children_remote(
+            _self: SpaceHierarchyHandler, space_id: str, via: Iterable[str]
+        ) -> Tuple[
+            Sequence[Tuple[str, Iterable[str]]], Mapping[str, Optional[JsonDict]]
+        ]:
+            if space_id == "!space:remote":
+                self.assertEqual(via, ["remote"])
+                return [("!room:remote", ["remote"])], {}
+            elif space_id == "!room:remote":
+                self.assertEqual(via, ["remote"])
+                return [], {}
+            else:
+                self.fail(
+                    f"Unexpected _get_space_children_remote({space_id!r}, {via!r}) call"
+                )
+                raise  # `fail` is missing type hints
+
+        with mock.patch(
+            "synapse.handlers.space_hierarchy.SpaceHierarchyHandler._get_space_children_remote",
+            new=_get_space_children_remote,
+        ):
+            descendants, inaccessible_room_ids = self.get_success(
+                self.handler.get_space_descendants(
+                    space_id, via=["remote"], enable_federation=True
+                )
+            )
+
+        self.assertEqual(descendants, [(space_id, ["remote"]), (room_id, ["remote"])])
+        self.assertEqual(inaccessible_room_ids, [space_id, room_id])
+
+    def test_remote_space_with_federation_disabled(self):
+        """Tests iteration over a remote space with federation disabled."""
+        space_id = "!space:remote"
+
+        async def _get_space_children_remote(
+            _self: SpaceHierarchyHandler, space_id: str, via: Iterable[str]
+        ) -> NoReturn:
+            self.fail(
+                f"Unexpected _get_space_children_remote({space_id!r}, {via!r}) call"
+            )
+            raise  # `fail` is missing type hints
+
+        with mock.patch(
+            "synapse.handlers.space_hierarchy.SpaceHierarchyHandler._get_space_children_remote",
+            new=_get_space_children_remote,
+        ):
+            descendants, inaccessible_room_ids = self.get_success(
+                self.handler.get_space_descendants(
+                    space_id, via=["remote"], enable_federation=False
+                )
+            )
+
+        self.assertEqual(descendants, [(space_id, ["remote"])])
+        self.assertEqual(inaccessible_room_ids, [space_id])
+
     def test_cycle(self):
         """Tests iteration over a cyclic space."""
         # space_id
diff --git a/tests/rest/admin/test_space.py b/tests/rest/admin/test_space.py
index 70d6776258..2aa5a26142 100644
--- a/tests/rest/admin/test_space.py
+++ b/tests/rest/admin/test_space.py
@@ -12,7 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, Optional, Tuple, Union
+from typing import (
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    NoReturn,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
+from unittest import mock
 
 from typing_extensions import Literal
 
@@ -28,6 +39,7 @@ from synapse.api.constants import (
     RoomTypes,
 )
 from synapse.api.room_versions import RoomVersions
+from synapse.handlers.space_hierarchy import SpaceHierarchyHandler
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
 from synapse.types import JsonDict
@@ -73,12 +85,17 @@ class RemoveSpaceMemberTestCase(unittest.HomeserverTestCase):
         self.subspace_id = self._create_space((JoinRules.RESTRICTED, self.space_id))
         self._add_child(self.space_id, self.subspace_id)
 
-    def _add_child(self, space_id: str, room_id: str) -> None:
+    def _add_child(
+        self, space_id: str, room_id: str, via: Optional[List[str]] = None
+    ) -> None:
         """Adds a room to a space."""
+        if via is None:
+            via = [self.hs.hostname]
+
         self.helper.send_state(
             space_id,
             event_type=EventTypes.SpaceChild,
-            body={"via": [self.hs.hostname]},
+            body={"via": via},
             tok=self.space_owner_user_tok,
             state_key=room_id,
         )
@@ -141,13 +158,26 @@ class RemoveSpaceMemberTestCase(unittest.HomeserverTestCase):
 
         return room_id
 
-    def _remove_from_space(self, user_id: str) -> JsonDict:
+    def _remove_from_space(
+        self,
+        user_id: str,
+        space_id: Optional[str] = None,
+        include_remote_spaces: Optional[bool] = None,
+    ) -> JsonDict:
         """Removes the given user from the test space."""
+        if space_id is None:
+            space_id = self.space_id
+
+        content: Union[bytes, JsonDict] = b""
+        if include_remote_spaces is not None:
+            content = {"include_remote_spaces": include_remote_spaces}
+
         url = f"/_synapse/admin/v1/rooms/{self.space_id}/hierarchy/members/{user_id}"
         channel = self.make_request(
             "DELETE",
             url.encode("ascii"),
             access_token=self.admin_user_tok,
+            content=content,
         )
 
         self.assertEqual(200, channel.code, channel.json_body)
@@ -287,3 +317,83 @@ class RemoveSpaceMemberTestCase(unittest.HomeserverTestCase):
             )
         )
         self.assertEqual(membership, Membership.LEAVE)
+
+    def test_remote_space(self) -> None:
+        """Tests that the user is made to leave rooms in a remote space."""
+        remote_space_id = "!space:remote"
+        self._add_child(self.subspace_id, remote_space_id, via=["remote"])
+
+        restricted_room_id = self._create_room((JoinRules.RESTRICTED, self.space_id))
+        self.helper.join(restricted_room_id, self.target_user, tok=self.target_user_tok)
+
+        async def _get_space_children_remote(
+            _self: SpaceHierarchyHandler, space_id: str, via: Iterable[str]
+        ) -> Tuple[
+            Sequence[Tuple[str, Iterable[str]]], Mapping[str, Optional[JsonDict]]
+        ]:
+            self.assertEqual(space_id, remote_space_id)
+            self.assertEqual(via, ["remote"])
+
+            return [(restricted_room_id, [self.hs.hostname])], {}
+
+        with mock.patch(
+            "synapse.handlers.space_hierarchy.SpaceHierarchyHandler._get_space_children_remote",
+            new=_get_space_children_remote,
+        ):
+            response = self._remove_from_space(
+                self.target_user, space_id="!space:remote", include_remote_spaces=True
+            )
+        self.assertEqual(
+            response,
+            {
+                "left_rooms": [self.space_id, restricted_room_id],
+                "inaccessible_rooms": [remote_space_id],
+                "failed_rooms": {},
+            },
+        )
+
+        membership, _ = self.get_success(
+            self.store.get_local_current_membership_for_user_in_room(
+                self.target_user, restricted_room_id
+            )
+        )
+        self.assertEqual(membership, Membership.LEAVE)
+
+    def test_remote_spaces_excluded(self) -> None:
+        """Tests the exclusion of remote spaces."""
+        remote_space_id = "!space:remote"
+        self._add_child(self.subspace_id, remote_space_id, via=["remote"])
+
+        restricted_room_id = self._create_room((JoinRules.RESTRICTED, self.space_id))
+        self.helper.join(restricted_room_id, self.target_user, tok=self.target_user_tok)
+
+        async def _get_space_children_remote(
+            _self: SpaceHierarchyHandler, space_id: str, via: Iterable[str]
+        ) -> NoReturn:
+            self.fail(
+                f"Unexpected _get_space_children_remote({space_id!r}, {via!r}) call"
+            )
+            raise  # `fail` is missing type hints
+
+        with mock.patch(
+            "synapse.handlers.space_hierarchy.SpaceHierarchyHandler._get_space_children_remote",
+            new=_get_space_children_remote,
+        ):
+            response = self._remove_from_space(
+                self.target_user, space_id="!space:remote", include_remote_spaces=False
+            )
+        self.assertEqual(
+            response,
+            {
+                "left_rooms": [self.space_id],
+                "inaccessible_rooms": [remote_space_id],
+                "failed_rooms": {},
+            },
+        )
+
+        membership, _ = self.get_success(
+            self.store.get_local_current_membership_for_user_in_room(
+                self.target_user, restricted_room_id
+            )
+        )
+        self.assertEqual(membership, Membership.JOIN)