diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index e51f72d65f..3e2100eab4 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -21,13 +21,19 @@ import logging
from typing import Optional
from unittest.mock import patch
+from synapse.api.constants import EventUnsignedContentFields
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
-from synapse.types import JsonDict, create_requester
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import create_requester
from synapse.visibility import filter_events_for_client, filter_events_for_server
from tests import unittest
+from tests.test_utils.event_injection import inject_event, inject_member_event
+from tests.unittest import HomeserverTestCase
from tests.utils import create_room
logger = logging.getLogger(__name__)
@@ -56,15 +62,31 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
#
# before we do that, we persist some other events to act as state.
- self._inject_visibility("@admin:hs", "joined")
+ self.get_success(
+ inject_visibility_event(self.hs, TEST_ROOM_ID, "@admin:hs", "joined")
+ )
for i in range(10):
- self._inject_room_member("@resident%i:hs" % i)
+ self.get_success(
+ inject_member_event(
+ self.hs,
+ TEST_ROOM_ID,
+ "@resident%i:hs" % i,
+ "join",
+ )
+ )
events_to_filter = []
for i in range(10):
- user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
- evt = self._inject_room_member(user, extra_content={"a": "b"})
+ evt = self.get_success(
+ inject_member_event(
+ self.hs,
+ TEST_ROOM_ID,
+ "@user%i:%s" % (i, "test_server" if i == 5 else "other_server"),
+ "join",
+ extra_content={"a": "b"},
+ )
+ )
events_to_filter.append(evt)
filtered = self.get_success(
@@ -90,8 +112,19 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
def test_filter_outlier(self) -> None:
# outlier events must be returned, for the good of the collective federation
- self._inject_room_member("@resident:remote_hs")
- self._inject_visibility("@resident:remote_hs", "joined")
+ self.get_success(
+ inject_member_event(
+ self.hs,
+ TEST_ROOM_ID,
+ "@resident:remote_hs",
+ "join",
+ )
+ )
+ self.get_success(
+ inject_visibility_event(
+ self.hs, TEST_ROOM_ID, "@resident:remote_hs", "joined"
+ )
+ )
outlier = self._inject_outlier()
self.assertEqual(
@@ -110,7 +143,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
)
# it should also work when there are other events in the list
- evt = self._inject_message("@unerased:local_hs")
+ evt = self.get_success(
+ inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
+ )
filtered = self.get_success(
filter_events_for_server(
@@ -150,19 +185,34 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
# change in the middle of them.
events_to_filter = []
- evt = self._inject_message("@unerased:local_hs")
+ evt = self.get_success(
+ inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
+ )
events_to_filter.append(evt)
- evt = self._inject_message("@erased:local_hs")
+ evt = self.get_success(
+ inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs")
+ )
events_to_filter.append(evt)
- evt = self._inject_room_member("@joiner:remote_hs")
+ evt = self.get_success(
+ inject_member_event(
+ self.hs,
+ TEST_ROOM_ID,
+ "@joiner:remote_hs",
+ "join",
+ )
+ )
events_to_filter.append(evt)
- evt = self._inject_message("@unerased:local_hs")
+ evt = self.get_success(
+ inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
+ )
events_to_filter.append(evt)
- evt = self._inject_message("@erased:local_hs")
+ evt = self.get_success(
+ inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs")
+ )
events_to_filter.append(evt)
# the erasey user gets erased
@@ -200,76 +250,6 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
for i in (1, 4):
self.assertNotIn("body", filtered[i].content)
- def _inject_visibility(self, user_id: str, visibility: str) -> EventBase:
- content = {"history_visibility": visibility}
- builder = self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": "m.room.history_visibility",
- "sender": user_id,
- "state_key": "",
- "room_id": TEST_ROOM_ID,
- "content": content,
- },
- )
-
- event, unpersisted_context = self.get_success(
- self.event_creation_handler.create_new_client_event(builder)
- )
- context = self.get_success(unpersisted_context.persist(event))
- self.get_success(self._persistence.persist_event(event, context))
- return event
-
- def _inject_room_member(
- self,
- user_id: str,
- membership: str = "join",
- extra_content: Optional[JsonDict] = None,
- ) -> EventBase:
- content = {"membership": membership}
- content.update(extra_content or {})
- builder = self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": "m.room.member",
- "sender": user_id,
- "state_key": user_id,
- "room_id": TEST_ROOM_ID,
- "content": content,
- },
- )
-
- event, unpersisted_context = self.get_success(
- self.event_creation_handler.create_new_client_event(builder)
- )
- context = self.get_success(unpersisted_context.persist(event))
-
- self.get_success(self._persistence.persist_event(event, context))
- return event
-
- def _inject_message(
- self, user_id: str, content: Optional[JsonDict] = None
- ) -> EventBase:
- if content is None:
- content = {"body": "testytest", "msgtype": "m.text"}
- builder = self.event_builder_factory.for_room_version(
- RoomVersions.V1,
- {
- "type": "m.room.message",
- "sender": user_id,
- "room_id": TEST_ROOM_ID,
- "content": content,
- },
- )
-
- event, unpersisted_context = self.get_success(
- self.event_creation_handler.create_new_client_event(builder)
- )
- context = self.get_success(unpersisted_context.persist(event))
-
- self.get_success(self._persistence.persist_event(event, context))
- return event
-
def _inject_outlier(self) -> EventBase:
builder = self.event_builder_factory.for_room_version(
RoomVersions.V1,
@@ -292,7 +272,122 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
return event
-class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
+class FilterEventsForClientTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def test_joined_history_visibility(self) -> None:
+ # User joins and leaves room. Should be able to see the join and leave,
+ # and messages sent between the two, but not before or after.
+
+ self.register_user("resident", "p1")
+ resident_token = self.login("resident", "p1")
+ room_id = self.helper.create_room_as("resident", tok=resident_token)
+
+ self.get_success(
+ inject_visibility_event(self.hs, room_id, "@resident:test", "joined")
+ )
+ before_event = self.get_success(
+ inject_message_event(self.hs, room_id, "@resident:test", body="before")
+ )
+ join_event = self.get_success(
+ inject_member_event(self.hs, room_id, "@joiner:test", "join")
+ )
+ during_event = self.get_success(
+ inject_message_event(self.hs, room_id, "@resident:test", body="during")
+ )
+ leave_event = self.get_success(
+ inject_member_event(self.hs, room_id, "@joiner:test", "leave")
+ )
+ after_event = self.get_success(
+ inject_message_event(self.hs, room_id, "@resident:test", body="after")
+ )
+
+ # We have to reload the events from the db, to ensure that prev_content is
+ # populated.
+ events_to_filter = [
+ self.get_success(
+ self.hs.get_storage_controllers().main.get_event(
+ e.event_id,
+ get_prev_content=True,
+ )
+ )
+ for e in [
+ before_event,
+ join_event,
+ during_event,
+ leave_event,
+ after_event,
+ ]
+ ]
+
+ # Now run the events through the filter, and check that we can see the events
+ # we expect, and that the membership prop is as expected.
+ #
+ # We deliberately do the queries for both users upfront; this simulates
+ # concurrent queries on the server, and helps ensure that we aren't
+ # accidentally serving the same event object (with the same unsigned.membership
+ # property) to both users.
+ joiner_filtered_events = self.get_success(
+ filter_events_for_client(
+ self.hs.get_storage_controllers(),
+ "@joiner:test",
+ events_to_filter,
+ msc4115_membership_on_events=True,
+ )
+ )
+ resident_filtered_events = self.get_success(
+ filter_events_for_client(
+ self.hs.get_storage_controllers(),
+ "@resident:test",
+ events_to_filter,
+ msc4115_membership_on_events=True,
+ )
+ )
+
+ # The joiner should be able to seem the join and leave,
+ # and messages sent between the two, but not before or after.
+ self.assertEqual(
+ [e.event_id for e in [join_event, during_event, leave_event]],
+ [e.event_id for e in joiner_filtered_events],
+ )
+ self.assertEqual(
+ ["join", "join", "leave"],
+ [
+ e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP]
+ for e in joiner_filtered_events
+ ],
+ )
+
+ # The resident user should see all the events.
+ self.assertEqual(
+ [
+ e.event_id
+ for e in [
+ before_event,
+ join_event,
+ during_event,
+ leave_event,
+ after_event,
+ ]
+ ],
+ [e.event_id for e in resident_filtered_events],
+ )
+ self.assertEqual(
+ ["join", "join", "join", "join", "join"],
+ [
+ e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP]
+ for e in resident_filtered_events
+ ],
+ )
+
+
+class FilterEventsOutOfBandEventsForClientTestCase(
+ unittest.FederatingHomeserverTestCase
+):
def test_out_of_band_invite_rejection(self) -> None:
# this is where we have received an invite event over federation, and then
# rejected it.
@@ -341,15 +436,24 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
)
# the invited user should be able to see both the invite and the rejection
+ filtered_events = self.get_success(
+ filter_events_for_client(
+ self.hs.get_storage_controllers(),
+ "@user:test",
+ [invite_event, reject_event],
+ msc4115_membership_on_events=True,
+ )
+ )
self.assertEqual(
- self.get_success(
- filter_events_for_client(
- self.hs.get_storage_controllers(),
- "@user:test",
- [invite_event, reject_event],
- )
- ),
- [invite_event, reject_event],
+ [e.event_id for e in filtered_events],
+ [e.event_id for e in [invite_event, reject_event]],
+ )
+ self.assertEqual(
+ ["invite", "leave"],
+ [
+ e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP]
+ for e in filtered_events
+ ],
)
# other users should see neither
@@ -359,7 +463,39 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
self.hs.get_storage_controllers(),
"@other:test",
[invite_event, reject_event],
+ msc4115_membership_on_events=True,
)
),
[],
)
+
+
+async def inject_visibility_event(
+ hs: HomeServer,
+ room_id: str,
+ sender: str,
+ visibility: str,
+) -> EventBase:
+ return await inject_event(
+ hs,
+ type="m.room.history_visibility",
+ sender=sender,
+ state_key="",
+ room_id=room_id,
+ content={"history_visibility": visibility},
+ )
+
+
+async def inject_message_event(
+ hs: HomeServer,
+ room_id: str,
+ sender: str,
+ body: Optional[str] = "testytest",
+) -> EventBase:
+ return await inject_event(
+ hs,
+ type="m.room.message",
+ sender=sender,
+ room_id=room_id,
+ content={"body": body, "msgtype": "m.text"},
+ )
|