diff options
-rw-r--r-- | tests/handlers/test_space_hierarchy.py | 64 | ||||
-rw-r--r-- | tests/rest/admin/test_space.py | 118 |
2 files changed, 177 insertions, 5 deletions
diff --git a/tests/handlers/test_space_hierarchy.py b/tests/handlers/test_space_hierarchy.py index 548173d1db..63bc93d558 100644 --- a/tests/handlers/test_space_hierarchy.py +++ b/tests/handlers/test_space_hierarchy.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional +from typing import Dict, Iterable, Mapping, NoReturn, Optional, Sequence, Tuple +from unittest import mock from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventContentFields, EventTypes, RoomTypes +from synapse.handlers.space_hierarchy import SpaceHierarchyHandler from synapse.rest import admin from synapse.rest.client import login, room from synapse.server import HomeServer @@ -120,6 +122,66 @@ class SpaceDescendantsTestCase(unittest.HomeserverTestCase): self.assertEqual(descendants, [(space_id, []), (room_id, [self.hs.hostname])]) self.assertEqual(inaccessible_room_ids, [room_id]) + def test_remote_space_with_federation_enabled(self): + """Tests iteration over a remote space with federation enabled.""" + space_id = "!space:remote" + room_id = "!room:remote" + + async def _get_space_children_remote( + _self: SpaceHierarchyHandler, space_id: str, via: Iterable[str] + ) -> Tuple[ + Sequence[Tuple[str, Iterable[str]]], Mapping[str, Optional[JsonDict]] + ]: + if space_id == "!space:remote": + self.assertEqual(via, ["remote"]) + return [("!room:remote", ["remote"])], {} + elif space_id == "!room:remote": + self.assertEqual(via, ["remote"]) + return [], {} + else: + self.fail( + f"Unexpected _get_space_children_remote({space_id!r}, {via!r}) call" + ) + raise # `fail` is missing type hints + + with mock.patch( + "synapse.handlers.space_hierarchy.SpaceHierarchyHandler._get_space_children_remote", + new=_get_space_children_remote, + ): + descendants, inaccessible_room_ids = self.get_success( + self.handler.get_space_descendants( + space_id, via=["remote"], enable_federation=True + ) + ) + + self.assertEqual(descendants, [(space_id, ["remote"]), (room_id, ["remote"])]) + self.assertEqual(inaccessible_room_ids, [space_id, room_id]) + + def test_remote_space_with_federation_disabled(self): + """Tests iteration over a remote space with federation disabled.""" + space_id = "!space:remote" + + async def _get_space_children_remote( + _self: SpaceHierarchyHandler, space_id: str, via: Iterable[str] + ) -> NoReturn: + self.fail( + f"Unexpected _get_space_children_remote({space_id!r}, {via!r}) call" + ) + raise # `fail` is missing type hints + + with mock.patch( + "synapse.handlers.space_hierarchy.SpaceHierarchyHandler._get_space_children_remote", + new=_get_space_children_remote, + ): + descendants, inaccessible_room_ids = self.get_success( + self.handler.get_space_descendants( + space_id, via=["remote"], enable_federation=False + ) + ) + + self.assertEqual(descendants, [(space_id, ["remote"])]) + self.assertEqual(inaccessible_room_ids, [space_id]) + def test_cycle(self): """Tests iteration over a cyclic space.""" # space_id diff --git a/tests/rest/admin/test_space.py b/tests/rest/admin/test_space.py index 70d6776258..2aa5a26142 100644 --- a/tests/rest/admin/test_space.py +++ b/tests/rest/admin/test_space.py @@ -12,7 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Tuple, Union +from typing import ( + Dict, + Iterable, + List, + Mapping, + NoReturn, + Optional, + Sequence, + Tuple, + Union, +) +from unittest import mock from typing_extensions import Literal @@ -28,6 +39,7 @@ from synapse.api.constants import ( RoomTypes, ) from synapse.api.room_versions import RoomVersions +from synapse.handlers.space_hierarchy import SpaceHierarchyHandler from synapse.rest.client import login, room from synapse.server import HomeServer from synapse.types import JsonDict @@ -73,12 +85,17 @@ class RemoveSpaceMemberTestCase(unittest.HomeserverTestCase): self.subspace_id = self._create_space((JoinRules.RESTRICTED, self.space_id)) self._add_child(self.space_id, self.subspace_id) - def _add_child(self, space_id: str, room_id: str) -> None: + def _add_child( + self, space_id: str, room_id: str, via: Optional[List[str]] = None + ) -> None: """Adds a room to a space.""" + if via is None: + via = [self.hs.hostname] + self.helper.send_state( space_id, event_type=EventTypes.SpaceChild, - body={"via": [self.hs.hostname]}, + body={"via": via}, tok=self.space_owner_user_tok, state_key=room_id, ) @@ -141,13 +158,26 @@ class RemoveSpaceMemberTestCase(unittest.HomeserverTestCase): return room_id - def _remove_from_space(self, user_id: str) -> JsonDict: + def _remove_from_space( + self, + user_id: str, + space_id: Optional[str] = None, + include_remote_spaces: Optional[bool] = None, + ) -> JsonDict: """Removes the given user from the test space.""" + if space_id is None: + space_id = self.space_id + + content: Union[bytes, JsonDict] = b"" + if include_remote_spaces is not None: + content = {"include_remote_spaces": include_remote_spaces} + url = f"/_synapse/admin/v1/rooms/{self.space_id}/hierarchy/members/{user_id}" channel = self.make_request( "DELETE", url.encode("ascii"), access_token=self.admin_user_tok, + content=content, ) self.assertEqual(200, channel.code, channel.json_body) @@ -287,3 +317,83 @@ class RemoveSpaceMemberTestCase(unittest.HomeserverTestCase): ) ) self.assertEqual(membership, Membership.LEAVE) + + def test_remote_space(self) -> None: + """Tests that the user is made to leave rooms in a remote space.""" + remote_space_id = "!space:remote" + self._add_child(self.subspace_id, remote_space_id, via=["remote"]) + + restricted_room_id = self._create_room((JoinRules.RESTRICTED, self.space_id)) + self.helper.join(restricted_room_id, self.target_user, tok=self.target_user_tok) + + async def _get_space_children_remote( + _self: SpaceHierarchyHandler, space_id: str, via: Iterable[str] + ) -> Tuple[ + Sequence[Tuple[str, Iterable[str]]], Mapping[str, Optional[JsonDict]] + ]: + self.assertEqual(space_id, remote_space_id) + self.assertEqual(via, ["remote"]) + + return [(restricted_room_id, [self.hs.hostname])], {} + + with mock.patch( + "synapse.handlers.space_hierarchy.SpaceHierarchyHandler._get_space_children_remote", + new=_get_space_children_remote, + ): + response = self._remove_from_space( + self.target_user, space_id="!space:remote", include_remote_spaces=True + ) + self.assertEqual( + response, + { + "left_rooms": [self.space_id, restricted_room_id], + "inaccessible_rooms": [remote_space_id], + "failed_rooms": {}, + }, + ) + + membership, _ = self.get_success( + self.store.get_local_current_membership_for_user_in_room( + self.target_user, restricted_room_id + ) + ) + self.assertEqual(membership, Membership.LEAVE) + + def test_remote_spaces_excluded(self) -> None: + """Tests the exclusion of remote spaces.""" + remote_space_id = "!space:remote" + self._add_child(self.subspace_id, remote_space_id, via=["remote"]) + + restricted_room_id = self._create_room((JoinRules.RESTRICTED, self.space_id)) + self.helper.join(restricted_room_id, self.target_user, tok=self.target_user_tok) + + async def _get_space_children_remote( + _self: SpaceHierarchyHandler, space_id: str, via: Iterable[str] + ) -> NoReturn: + self.fail( + f"Unexpected _get_space_children_remote({space_id!r}, {via!r}) call" + ) + raise # `fail` is missing type hints + + with mock.patch( + "synapse.handlers.space_hierarchy.SpaceHierarchyHandler._get_space_children_remote", + new=_get_space_children_remote, + ): + response = self._remove_from_space( + self.target_user, space_id="!space:remote", include_remote_spaces=False + ) + self.assertEqual( + response, + { + "left_rooms": [self.space_id], + "inaccessible_rooms": [remote_space_id], + "failed_rooms": {}, + }, + ) + + membership, _ = self.get_success( + self.store.get_local_current_membership_for_user_in_room( + self.target_user, restricted_room_id + ) + ) + self.assertEqual(membership, Membership.JOIN) |