summary refs log tree commit diff
path: root/tests/test_visibility.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/test_visibility.py320
1 files changed, 228 insertions, 92 deletions
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index e51f72d65f..3e2100eab4 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -21,13 +21,19 @@ import logging
 from typing import Optional
 from unittest.mock import patch
 
+from synapse.api.constants import EventUnsignedContentFields
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.snapshot import EventContext
-from synapse.types import JsonDict, create_requester
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import create_requester
 from synapse.visibility import filter_events_for_client, filter_events_for_server
 
 from tests import unittest
+from tests.test_utils.event_injection import inject_event, inject_member_event
+from tests.unittest import HomeserverTestCase
 from tests.utils import create_room
 
 logger = logging.getLogger(__name__)
@@ -56,15 +62,31 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         #
 
         # before we do that, we persist some other events to act as state.
-        self._inject_visibility("@admin:hs", "joined")
+        self.get_success(
+            inject_visibility_event(self.hs, TEST_ROOM_ID, "@admin:hs", "joined")
+        )
         for i in range(10):
-            self._inject_room_member("@resident%i:hs" % i)
+            self.get_success(
+                inject_member_event(
+                    self.hs,
+                    TEST_ROOM_ID,
+                    "@resident%i:hs" % i,
+                    "join",
+                )
+            )
 
         events_to_filter = []
 
         for i in range(10):
-            user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
-            evt = self._inject_room_member(user, extra_content={"a": "b"})
+            evt = self.get_success(
+                inject_member_event(
+                    self.hs,
+                    TEST_ROOM_ID,
+                    "@user%i:%s" % (i, "test_server" if i == 5 else "other_server"),
+                    "join",
+                    extra_content={"a": "b"},
+                )
+            )
             events_to_filter.append(evt)
 
         filtered = self.get_success(
@@ -90,8 +112,19 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
     def test_filter_outlier(self) -> None:
         # outlier events must be returned, for the good of the collective federation
-        self._inject_room_member("@resident:remote_hs")
-        self._inject_visibility("@resident:remote_hs", "joined")
+        self.get_success(
+            inject_member_event(
+                self.hs,
+                TEST_ROOM_ID,
+                "@resident:remote_hs",
+                "join",
+            )
+        )
+        self.get_success(
+            inject_visibility_event(
+                self.hs, TEST_ROOM_ID, "@resident:remote_hs", "joined"
+            )
+        )
 
         outlier = self._inject_outlier()
         self.assertEqual(
@@ -110,7 +143,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         )
 
         # it should also work when there are other events in the list
-        evt = self._inject_message("@unerased:local_hs")
+        evt = self.get_success(
+            inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
+        )
 
         filtered = self.get_success(
             filter_events_for_server(
@@ -150,19 +185,34 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         # change in the middle of them.
         events_to_filter = []
 
-        evt = self._inject_message("@unerased:local_hs")
+        evt = self.get_success(
+            inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
+        )
         events_to_filter.append(evt)
 
-        evt = self._inject_message("@erased:local_hs")
+        evt = self.get_success(
+            inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs")
+        )
         events_to_filter.append(evt)
 
-        evt = self._inject_room_member("@joiner:remote_hs")
+        evt = self.get_success(
+            inject_member_event(
+                self.hs,
+                TEST_ROOM_ID,
+                "@joiner:remote_hs",
+                "join",
+            )
+        )
         events_to_filter.append(evt)
 
-        evt = self._inject_message("@unerased:local_hs")
+        evt = self.get_success(
+            inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
+        )
         events_to_filter.append(evt)
 
-        evt = self._inject_message("@erased:local_hs")
+        evt = self.get_success(
+            inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs")
+        )
         events_to_filter.append(evt)
 
         # the erasey user gets erased
@@ -200,76 +250,6 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         for i in (1, 4):
             self.assertNotIn("body", filtered[i].content)
 
-    def _inject_visibility(self, user_id: str, visibility: str) -> EventBase:
-        content = {"history_visibility": visibility}
-        builder = self.event_builder_factory.for_room_version(
-            RoomVersions.V1,
-            {
-                "type": "m.room.history_visibility",
-                "sender": user_id,
-                "state_key": "",
-                "room_id": TEST_ROOM_ID,
-                "content": content,
-            },
-        )
-
-        event, unpersisted_context = self.get_success(
-            self.event_creation_handler.create_new_client_event(builder)
-        )
-        context = self.get_success(unpersisted_context.persist(event))
-        self.get_success(self._persistence.persist_event(event, context))
-        return event
-
-    def _inject_room_member(
-        self,
-        user_id: str,
-        membership: str = "join",
-        extra_content: Optional[JsonDict] = None,
-    ) -> EventBase:
-        content = {"membership": membership}
-        content.update(extra_content or {})
-        builder = self.event_builder_factory.for_room_version(
-            RoomVersions.V1,
-            {
-                "type": "m.room.member",
-                "sender": user_id,
-                "state_key": user_id,
-                "room_id": TEST_ROOM_ID,
-                "content": content,
-            },
-        )
-
-        event, unpersisted_context = self.get_success(
-            self.event_creation_handler.create_new_client_event(builder)
-        )
-        context = self.get_success(unpersisted_context.persist(event))
-
-        self.get_success(self._persistence.persist_event(event, context))
-        return event
-
-    def _inject_message(
-        self, user_id: str, content: Optional[JsonDict] = None
-    ) -> EventBase:
-        if content is None:
-            content = {"body": "testytest", "msgtype": "m.text"}
-        builder = self.event_builder_factory.for_room_version(
-            RoomVersions.V1,
-            {
-                "type": "m.room.message",
-                "sender": user_id,
-                "room_id": TEST_ROOM_ID,
-                "content": content,
-            },
-        )
-
-        event, unpersisted_context = self.get_success(
-            self.event_creation_handler.create_new_client_event(builder)
-        )
-        context = self.get_success(unpersisted_context.persist(event))
-
-        self.get_success(self._persistence.persist_event(event, context))
-        return event
-
     def _inject_outlier(self) -> EventBase:
         builder = self.event_builder_factory.for_room_version(
             RoomVersions.V1,
@@ -292,7 +272,122 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         return event
 
 
-class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
+class FilterEventsForClientTestCase(HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def test_joined_history_visibility(self) -> None:
+        # User joins and leaves room. Should be able to see the join and leave,
+        # and messages sent between the two, but not before or after.
+
+        self.register_user("resident", "p1")
+        resident_token = self.login("resident", "p1")
+        room_id = self.helper.create_room_as("resident", tok=resident_token)
+
+        self.get_success(
+            inject_visibility_event(self.hs, room_id, "@resident:test", "joined")
+        )
+        before_event = self.get_success(
+            inject_message_event(self.hs, room_id, "@resident:test", body="before")
+        )
+        join_event = self.get_success(
+            inject_member_event(self.hs, room_id, "@joiner:test", "join")
+        )
+        during_event = self.get_success(
+            inject_message_event(self.hs, room_id, "@resident:test", body="during")
+        )
+        leave_event = self.get_success(
+            inject_member_event(self.hs, room_id, "@joiner:test", "leave")
+        )
+        after_event = self.get_success(
+            inject_message_event(self.hs, room_id, "@resident:test", body="after")
+        )
+
+        # We have to reload the events from the db, to ensure that prev_content is
+        # populated.
+        events_to_filter = [
+            self.get_success(
+                self.hs.get_storage_controllers().main.get_event(
+                    e.event_id,
+                    get_prev_content=True,
+                )
+            )
+            for e in [
+                before_event,
+                join_event,
+                during_event,
+                leave_event,
+                after_event,
+            ]
+        ]
+
+        # Now run the events through the filter, and check that we can see the events
+        # we expect, and that the membership prop is as expected.
+        #
+        # We deliberately do the queries for both users upfront; this simulates
+        # concurrent queries on the server, and helps ensure that we aren't
+        # accidentally serving the same event object (with the same unsigned.membership
+        # property) to both users.
+        joiner_filtered_events = self.get_success(
+            filter_events_for_client(
+                self.hs.get_storage_controllers(),
+                "@joiner:test",
+                events_to_filter,
+                msc4115_membership_on_events=True,
+            )
+        )
+        resident_filtered_events = self.get_success(
+            filter_events_for_client(
+                self.hs.get_storage_controllers(),
+                "@resident:test",
+                events_to_filter,
+                msc4115_membership_on_events=True,
+            )
+        )
+
+        # The joiner should be able to seem the join and leave,
+        # and messages sent between the two, but not before or after.
+        self.assertEqual(
+            [e.event_id for e in [join_event, during_event, leave_event]],
+            [e.event_id for e in joiner_filtered_events],
+        )
+        self.assertEqual(
+            ["join", "join", "leave"],
+            [
+                e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP]
+                for e in joiner_filtered_events
+            ],
+        )
+
+        # The resident user should see all the events.
+        self.assertEqual(
+            [
+                e.event_id
+                for e in [
+                    before_event,
+                    join_event,
+                    during_event,
+                    leave_event,
+                    after_event,
+                ]
+            ],
+            [e.event_id for e in resident_filtered_events],
+        )
+        self.assertEqual(
+            ["join", "join", "join", "join", "join"],
+            [
+                e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP]
+                for e in resident_filtered_events
+            ],
+        )
+
+
+class FilterEventsOutOfBandEventsForClientTestCase(
+    unittest.FederatingHomeserverTestCase
+):
     def test_out_of_band_invite_rejection(self) -> None:
         # this is where we have received an invite event over federation, and then
         # rejected it.
@@ -341,15 +436,24 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
         )
 
         # the invited user should be able to see both the invite and the rejection
+        filtered_events = self.get_success(
+            filter_events_for_client(
+                self.hs.get_storage_controllers(),
+                "@user:test",
+                [invite_event, reject_event],
+                msc4115_membership_on_events=True,
+            )
+        )
         self.assertEqual(
-            self.get_success(
-                filter_events_for_client(
-                    self.hs.get_storage_controllers(),
-                    "@user:test",
-                    [invite_event, reject_event],
-                )
-            ),
-            [invite_event, reject_event],
+            [e.event_id for e in filtered_events],
+            [e.event_id for e in [invite_event, reject_event]],
+        )
+        self.assertEqual(
+            ["invite", "leave"],
+            [
+                e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP]
+                for e in filtered_events
+            ],
         )
 
         # other users should see neither
@@ -359,7 +463,39 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
                     self.hs.get_storage_controllers(),
                     "@other:test",
                     [invite_event, reject_event],
+                    msc4115_membership_on_events=True,
                 )
             ),
             [],
         )
+
+
+async def inject_visibility_event(
+    hs: HomeServer,
+    room_id: str,
+    sender: str,
+    visibility: str,
+) -> EventBase:
+    return await inject_event(
+        hs,
+        type="m.room.history_visibility",
+        sender=sender,
+        state_key="",
+        room_id=room_id,
+        content={"history_visibility": visibility},
+    )
+
+
+async def inject_message_event(
+    hs: HomeServer,
+    room_id: str,
+    sender: str,
+    body: Optional[str] = "testytest",
+) -> EventBase:
+    return await inject_event(
+        hs,
+        type="m.room.message",
+        sender=sender,
+        room_id=room_id,
+        content={"body": body, "msgtype": "m.text"},
+    )