summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2024-04-29 15:22:13 +0100
committerGitHub <noreply@github.com>2024-04-29 15:22:13 +0100
commitb548f7803a9b7ba51a66d47ddb9bb69dce541a48 (patch)
treeedf0a111f84a08464c6551c4d75e3667fa13c60b
parentUpdate tornado 6.2 -> 6.4 (#17131) (diff)
downloadsynapse-b548f7803a9b7ba51a66d47ddb9bb69dce541a48.tar.xz
Add support for MSC4115 (#17104)
Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
-rw-r--r--changelog.d/17104.feature1
-rw-r--r--docker/complement/conf/workers-shared-extra.yaml.j24
-rw-r--r--rust/src/events/internal_metadata.rs9
-rw-r--r--synapse/api/constants.py7
-rw-r--r--synapse/config/experimental.py4
-rw-r--r--synapse/events/utils.py30
-rw-r--r--synapse/handlers/admin.py6
-rw-r--r--synapse/handlers/events.py7
-rw-r--r--synapse/handlers/initial_sync.py7
-rw-r--r--synapse/handlers/pagination.py1
-rw-r--r--synapse/handlers/relations.py3
-rw-r--r--synapse/handlers/room.py1
-rw-r--r--synapse/handlers/search.py20
-rw-r--r--synapse/handlers/sync.py2
-rw-r--r--synapse/notifier.py1
-rw-r--r--synapse/push/mailer.py5
-rw-r--r--synapse/visibility.py73
-rw-r--r--tests/events/test_utils.py24
-rw-r--r--tests/rest/client/test_retention.py7
-rw-r--r--tests/test_visibility.py320
20 files changed, 407 insertions, 125 deletions
diff --git a/changelog.d/17104.feature b/changelog.d/17104.feature
new file mode 100644
index 0000000000..1c2355e155
--- /dev/null
+++ b/changelog.d/17104.feature
@@ -0,0 +1 @@
+Add support for MSC4115 (membership metadata on events).
diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2
index 32eada4419..a2c378f547 100644
--- a/docker/complement/conf/workers-shared-extra.yaml.j2
+++ b/docker/complement/conf/workers-shared-extra.yaml.j2
@@ -92,8 +92,6 @@ allow_device_name_lookup_over_federation: true
 ## Experimental Features ##
 
 experimental_features:
-  # client-side support for partial state in /send_join responses
-  faster_joins: true
   # Enable support for polls
   msc3381_polls_enabled: true
   # Enable deleting device-specific notification settings stored in account data
@@ -105,6 +103,8 @@ experimental_features:
   # no UIA for x-signing upload for the first time
   msc3967_enabled: true
 
+  msc4115_membership_on_events: true
+
 server_notices:
   system_mxid_localpart: _server
   system_mxid_display_name: "Server Alert"
diff --git a/rust/src/events/internal_metadata.rs b/rust/src/events/internal_metadata.rs
index a53601862d..53c7b1ba61 100644
--- a/rust/src/events/internal_metadata.rs
+++ b/rust/src/events/internal_metadata.rs
@@ -20,8 +20,10 @@
 
 //! Implements the internal metadata class attached to events.
 //!
-//! The internal metadata is a bit like a `TypedDict`, in that it is stored as a
-//! JSON dict in the DB. Most events have zero, or only a few, of these keys
+//! The internal metadata is a bit like a `TypedDict`, in that most of
+//! it is stored as a JSON dict in the DB (the exceptions being `outlier`
+//! and `stream_ordering` which have their own columns in the database).
+//! Most events have zero, or only a few, of these keys
 //! set. Therefore, since we care more about memory size than performance here,
 //! we store these fields in a mapping.
 //!
@@ -234,6 +236,9 @@ impl EventInternalMetadata {
         self.clone()
     }
 
+    /// Get a dict holding the data stored in the `internal_metadata` column in the database.
+    ///
+    /// Note that `outlier` and `stream_ordering` are stored in separate columns so are not returned here.
     fn get_dict(&self, py: Python<'_>) -> PyResult<PyObject> {
         let dict = PyDict::new(py);
 
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 98884b4967..0a9123c56b 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -234,6 +234,13 @@ class EventContentFields:
     TO_DEVICE_MSGID: Final = "org.matrix.msgid"
 
 
+class EventUnsignedContentFields:
+    """Fields found inside the 'unsigned' data on events"""
+
+    # Requesting user's membership, per MSC4115
+    MSC4115_MEMBERSHIP: Final = "io.element.msc4115.membership"
+
+
 class RoomTypes:
     """Understood values of the room_type field of m.room.create events."""
 
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index baa3580f29..749452ce93 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -432,3 +432,7 @@ class ExperimentalConfig(Config):
                 "You cannot have MSC4108 both enabled and delegated at the same time",
                 ("experimental", "msc4108_delegation_endpoint"),
             )
+
+        self.msc4115_membership_on_events = experimental.get(
+            "msc4115_membership_on_events", False
+        )
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index e0613d0dbc..0772472312 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -49,7 +49,7 @@ from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
 from synapse.types import JsonDict, Requester
 
-from . import EventBase
+from . import EventBase, make_event_from_dict
 
 if TYPE_CHECKING:
     from synapse.handlers.relations import BundledAggregations
@@ -82,17 +82,14 @@ def prune_event(event: EventBase) -> EventBase:
     """
     pruned_event_dict = prune_event_dict(event.room_version, event.get_dict())
 
-    from . import make_event_from_dict
-
     pruned_event = make_event_from_dict(
         pruned_event_dict, event.room_version, event.internal_metadata.get_dict()
     )
 
-    # copy the internal fields
+    # Copy the bits of `internal_metadata` that aren't returned by `get_dict`
     pruned_event.internal_metadata.stream_ordering = (
         event.internal_metadata.stream_ordering
     )
-
     pruned_event.internal_metadata.outlier = event.internal_metadata.outlier
 
     # Mark the event as redacted
@@ -101,6 +98,29 @@ def prune_event(event: EventBase) -> EventBase:
     return pruned_event
 
 
+def clone_event(event: EventBase) -> EventBase:
+    """Take a copy of the event.
+
+    This is mostly useful because it does a *shallow* copy of the `unsigned` data,
+    which means it can then be updated without corrupting the in-memory cache. Note that
+    other properties of the event, such as `content`, are *not* (currently) copied here.
+    """
+    # XXX: We rely on at least one of `event.get_dict()` and `make_event_from_dict()`
+    #   making a copy of `unsigned`. Currently, both do, though I don't really know why.
+    #   Still, as long as they do, there's not much point doing yet another copy here.
+    new_event = make_event_from_dict(
+        event.get_dict(), event.room_version, event.internal_metadata.get_dict()
+    )
+
+    # Copy the bits of `internal_metadata` that aren't returned by `get_dict`.
+    new_event.internal_metadata.stream_ordering = (
+        event.internal_metadata.stream_ordering
+    )
+    new_event.internal_metadata.outlier = event.internal_metadata.outlier
+
+    return new_event
+
+
 def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
     """Redacts the event_dict in the same way as `prune_event`, except it
     operates on dicts rather than event objects
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 360614e25b..702d40332c 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -42,6 +42,7 @@ class AdminHandler:
         self._device_handler = hs.get_device_handler()
         self._storage_controllers = hs.get_storage_controllers()
         self._state_storage_controller = self._storage_controllers.state
+        self._hs_config = hs.config
         self._msc3866_enabled = hs.config.experimental.msc3866.enabled
 
     async def get_whois(self, user: UserID) -> JsonMapping:
@@ -217,7 +218,10 @@ class AdminHandler:
                 )
 
                 events = await filter_events_for_client(
-                    self._storage_controllers, user_id, events
+                    self._storage_controllers,
+                    user_id,
+                    events,
+                    msc4115_membership_on_events=self._hs_config.experimental.msc4115_membership_on_events,
                 )
 
                 writer.write_events(room_id, events)
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index c3fee74a98..09d553cff1 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -148,6 +148,7 @@ class EventHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
+        self._config = hs.config
 
     async def get_event(
         self,
@@ -189,7 +190,11 @@ class EventHandler:
         is_peeking = not is_user_in_room
 
         filtered = await filter_events_for_client(
-            self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
+            self._storage_controllers,
+            user.to_string(),
+            [event],
+            is_peeking=is_peeking,
+            msc4115_membership_on_events=self._config.experimental.msc4115_membership_on_events,
         )
 
         if not filtered:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index bcc5b285ac..d99fc4bec0 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -221,7 +221,10 @@ class InitialSyncHandler:
                 ).addErrback(unwrapFirstError)
 
                 messages = await filter_events_for_client(
-                    self._storage_controllers, user_id, messages
+                    self._storage_controllers,
+                    user_id,
+                    messages,
+                    msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
                 )
 
                 start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
@@ -380,6 +383,7 @@ class InitialSyncHandler:
             requester.user.to_string(),
             messages,
             is_peeking=is_peeking,
+            msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
         )
 
         start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
@@ -494,6 +498,7 @@ class InitialSyncHandler:
             requester.user.to_string(),
             messages,
             is_peeking=is_peeking,
+            msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
         )
 
         start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index cd3a9088cd..6617105cdb 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -623,6 +623,7 @@ class PaginationHandler:
                 user_id,
                 events,
                 is_peeking=(member_event_id is None),
+                msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
             )
 
         # if after the filter applied there are no more events
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 931ac0c813..c5cee8860b 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -95,6 +95,7 @@ class RelationsHandler:
         self._event_handler = hs.get_event_handler()
         self._event_serializer = hs.get_event_client_serializer()
         self._event_creation_handler = hs.get_event_creation_handler()
+        self._config = hs.config
 
     async def get_relations(
         self,
@@ -163,6 +164,7 @@ class RelationsHandler:
             user_id,
             events,
             is_peeking=(member_event_id is None),
+            msc4115_membership_on_events=self._config.experimental.msc4115_membership_on_events,
         )
 
         # The relations returned for the requested event do include their
@@ -608,6 +610,7 @@ class RelationsHandler:
             user_id,
             events,
             is_peeking=(member_event_id is None),
+            msc4115_membership_on_events=self._config.experimental.msc4115_membership_on_events,
         )
 
         aggregations = await self.get_bundled_aggregations(
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 5e81a51638..51739a2653 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1476,6 +1476,7 @@ class RoomContextHandler:
                 user.to_string(),
                 events,
                 is_peeking=is_peeking,
+                msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
             )
 
         event = await self.store.get_event(
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 19c5a2f257..fdbe98de3b 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -480,7 +480,10 @@ class SearchHandler:
         filtered_events = await search_filter.filter([r["event"] for r in results])
 
         events = await filter_events_for_client(
-            self._storage_controllers, user.to_string(), filtered_events
+            self._storage_controllers,
+            user.to_string(),
+            filtered_events,
+            msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
         )
 
         events.sort(key=lambda e: -rank_map[e.event_id])
@@ -579,7 +582,10 @@ class SearchHandler:
             filtered_events = await search_filter.filter([r["event"] for r in results])
 
             events = await filter_events_for_client(
-                self._storage_controllers, user.to_string(), filtered_events
+                self._storage_controllers,
+                user.to_string(),
+                filtered_events,
+                msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
             )
 
             room_events.extend(events)
@@ -664,11 +670,17 @@ class SearchHandler:
             )
 
             events_before = await filter_events_for_client(
-                self._storage_controllers, user.to_string(), res.events_before
+                self._storage_controllers,
+                user.to_string(),
+                res.events_before,
+                msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
             )
 
             events_after = await filter_events_for_client(
-                self._storage_controllers, user.to_string(), res.events_after
+                self._storage_controllers,
+                user.to_string(),
+                res.events_after,
+                msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
             )
 
             context: JsonDict = {
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index a6d54ee4b8..8ff45a3353 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -596,6 +596,7 @@ class SyncHandler:
                     sync_config.user.to_string(),
                     recents,
                     always_include_ids=current_state_ids,
+                    msc4115_membership_on_events=self.hs_config.experimental.msc4115_membership_on_events,
                 )
                 log_kv({"recents_after_visibility_filtering": len(recents)})
             else:
@@ -681,6 +682,7 @@ class SyncHandler:
                     sync_config.user.to_string(),
                     loaded_recents,
                     always_include_ids=current_state_ids,
+                    msc4115_membership_on_events=self.hs_config.experimental.msc4115_membership_on_events,
                 )
 
                 loaded_recents = []
diff --git a/synapse/notifier.py b/synapse/notifier.py
index e87333a80a..7c1cd3b5f2 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -721,6 +721,7 @@ class Notifier:
                         user.to_string(),
                         new_events,
                         is_peeking=is_peeking,
+                        msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
                     )
                 elif keyname == StreamKeyType.PRESENCE:
                     now = self.clock.time_msec()
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 7c15eb7440..49ce9d6dda 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -529,7 +529,10 @@ class Mailer:
         }
 
         the_events = await filter_events_for_client(
-            self._storage_controllers, user_id, results.events_before
+            self._storage_controllers,
+            user_id,
+            results.events_before,
+            msc4115_membership_on_events=self.hs.config.experimental.msc4115_membership_on_events,
         )
         the_events.append(notif_event)
 
diff --git a/synapse/visibility.py b/synapse/visibility.py
index d1d478129f..09a947ef15 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -36,10 +36,15 @@ from typing import (
 
 import attr
 
-from synapse.api.constants import EventTypes, HistoryVisibility, Membership
+from synapse.api.constants import (
+    EventTypes,
+    EventUnsignedContentFields,
+    HistoryVisibility,
+    Membership,
+)
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
-from synapse.events.utils import prune_event
+from synapse.events.utils import clone_event, prune_event
 from synapse.logging.opentracing import trace
 from synapse.storage.controllers import StorageControllers
 from synapse.storage.databases.main import DataStore
@@ -77,6 +82,7 @@ async def filter_events_for_client(
     is_peeking: bool = False,
     always_include_ids: FrozenSet[str] = frozenset(),
     filter_send_to_client: bool = True,
+    msc4115_membership_on_events: bool = False,
 ) -> List[EventBase]:
     """
     Check which events a user is allowed to see. If the user can see the event but its
@@ -95,9 +101,12 @@ async def filter_events_for_client(
         filter_send_to_client: Whether we're checking an event that's going to be
             sent to a client. This might not always be the case since this function can
             also be called to check whether a user can see the state at a given point.
+        msc4115_membership_on_events: Whether to include the requesting user's
+            membership in the "unsigned" data, per MSC4115.
 
     Returns:
-        The filtered events.
+        The filtered events. If `msc4115_membership_on_events` is true, the `unsigned`
+        data is annotated with the membership state of `user_id` at each event.
     """
     # Filter out events that have been soft failed so that we don't relay them
     # to clients.
@@ -134,7 +143,8 @@ async def filter_events_for_client(
             )
 
     def allowed(event: EventBase) -> Optional[EventBase]:
-        return _check_client_allowed_to_see_event(
+        state_after_event = event_id_to_state.get(event.event_id)
+        filtered = _check_client_allowed_to_see_event(
             user_id=user_id,
             event=event,
             clock=storage.main.clock,
@@ -142,13 +152,45 @@ async def filter_events_for_client(
             sender_ignored=event.sender in ignore_list,
             always_include_ids=always_include_ids,
             retention_policy=retention_policies[room_id],
-            state=event_id_to_state.get(event.event_id),
+            state=state_after_event,
             is_peeking=is_peeking,
             sender_erased=erased_senders.get(event.sender, False),
         )
+        if filtered is None:
+            return None
+
+        if not msc4115_membership_on_events:
+            return filtered
+
+        # Annotate the event with the user's membership after the event.
+        #
+        # Normally we just look in `state_after_event`, but if the event is an outlier
+        # we won't have such a state. The only outliers that are returned here are the
+        # user's own membership event, so we can just inspect that.
+
+        user_membership_event: Optional[EventBase]
+        if event.type == EventTypes.Member and event.state_key == user_id:
+            user_membership_event = event
+        elif state_after_event is not None:
+            user_membership_event = state_after_event.get((EventTypes.Member, user_id))
+        else:
+            # unreachable!
+            raise Exception("Missing state for event that is not user's own membership")
+
+        user_membership = (
+            user_membership_event.membership
+            if user_membership_event
+            else Membership.LEAVE
+        )
 
-    # Check each event: gives an iterable of None or (a potentially modified)
-    # EventBase.
+        # Copy the event before updating the unsigned data: this shouldn't be persisted
+        # to the cache!
+        cloned = clone_event(filtered)
+        cloned.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP] = user_membership
+
+        return cloned
+
+    # Check each event: gives an iterable of None or (a modified) EventBase.
     filtered_events = map(allowed, events)
 
     # Turn it into a list and remove None entries before returning.
@@ -396,7 +438,13 @@ def _check_client_allowed_to_see_event(
 
 @attr.s(frozen=True, slots=True, auto_attribs=True)
 class _CheckMembershipReturn:
-    "Return value of _check_membership"
+    """Return value of `_check_membership`.
+
+    Attributes:
+        allowed: Whether the user should be allowed to see the event.
+        joined: Whether the user was joined to the room at the event.
+    """
+
     allowed: bool
     joined: bool
 
@@ -408,12 +456,7 @@ def _check_membership(
     state: StateMap[EventBase],
     is_peeking: bool,
 ) -> _CheckMembershipReturn:
-    """Check whether the user can see the event due to their membership
-
-    Returns:
-        True if they can, False if they can't, plus the membership of the user
-        at the event.
-    """
+    """Check whether the user can see the event due to their membership"""
     # If the event is the user's own membership event, use the 'most joined'
     # membership
     membership = None
@@ -435,7 +478,7 @@ def _check_membership(
         if membership == "leave" and (
             prev_membership == "join" or prev_membership == "invite"
         ):
-            return _CheckMembershipReturn(True, membership == Membership.JOIN)
+            return _CheckMembershipReturn(True, False)
 
         new_priority = MEMBERSHIP_PRIORITY.index(membership)
         old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
diff --git a/tests/events/test_utils.py b/tests/events/test_utils.py
index cf81bcf52c..d5ac66a6ed 100644
--- a/tests/events/test_utils.py
+++ b/tests/events/test_utils.py
@@ -32,6 +32,7 @@ from synapse.events.utils import (
     PowerLevelsContent,
     SerializeEventConfig,
     _split_field,
+    clone_event,
     copy_and_fixup_power_levels_contents,
     maybe_upsert_event_field,
     prune_event,
@@ -611,6 +612,29 @@ class PruneEventTestCase(stdlib_unittest.TestCase):
         )
 
 
+class CloneEventTestCase(stdlib_unittest.TestCase):
+    def test_unsigned_is_copied(self) -> None:
+        original = make_event_from_dict(
+            {
+                "type": "A",
+                "event_id": "$test:domain",
+                "unsigned": {"a": 1, "b": 2},
+            },
+            RoomVersions.V1,
+            {"txn_id": "txn"},
+        )
+        original.internal_metadata.stream_ordering = 1234
+        self.assertEqual(original.internal_metadata.stream_ordering, 1234)
+
+        cloned = clone_event(original)
+        cloned.unsigned["b"] = 3
+
+        self.assertEqual(original.unsigned, {"a": 1, "b": 2})
+        self.assertEqual(cloned.unsigned, {"a": 1, "b": 3})
+        self.assertEqual(cloned.internal_metadata.stream_ordering, 1234)
+        self.assertEqual(cloned.internal_metadata.txn_id, "txn")
+
+
 class SerializeEventTestCase(stdlib_unittest.TestCase):
     def serialize(self, ev: EventBase, fields: Optional[List[str]]) -> JsonDict:
         return serialize_event(
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index 09a5d64349..ceae40498e 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -163,7 +163,12 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         )
         self.assertEqual(2, len(events), "events retrieved from database")
         filtered_events = self.get_success(
-            filter_events_for_client(storage_controllers, self.user_id, events)
+            filter_events_for_client(
+                storage_controllers,
+                self.user_id,
+                events,
+                msc4115_membership_on_events=True,
+            )
         )
 
         # We should only get one event back.
diff --git a/tests/test_visibility.py b/tests/test_visibility.py
index e51f72d65f..3e2100eab4 100644
--- a/tests/test_visibility.py
+++ b/tests/test_visibility.py
@@ -21,13 +21,19 @@ import logging
 from typing import Optional
 from unittest.mock import patch
 
+from synapse.api.constants import EventUnsignedContentFields
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase, make_event_from_dict
 from synapse.events.snapshot import EventContext
-from synapse.types import JsonDict, create_requester
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import create_requester
 from synapse.visibility import filter_events_for_client, filter_events_for_server
 
 from tests import unittest
+from tests.test_utils.event_injection import inject_event, inject_member_event
+from tests.unittest import HomeserverTestCase
 from tests.utils import create_room
 
 logger = logging.getLogger(__name__)
@@ -56,15 +62,31 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         #
 
         # before we do that, we persist some other events to act as state.
-        self._inject_visibility("@admin:hs", "joined")
+        self.get_success(
+            inject_visibility_event(self.hs, TEST_ROOM_ID, "@admin:hs", "joined")
+        )
         for i in range(10):
-            self._inject_room_member("@resident%i:hs" % i)
+            self.get_success(
+                inject_member_event(
+                    self.hs,
+                    TEST_ROOM_ID,
+                    "@resident%i:hs" % i,
+                    "join",
+                )
+            )
 
         events_to_filter = []
 
         for i in range(10):
-            user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
-            evt = self._inject_room_member(user, extra_content={"a": "b"})
+            evt = self.get_success(
+                inject_member_event(
+                    self.hs,
+                    TEST_ROOM_ID,
+                    "@user%i:%s" % (i, "test_server" if i == 5 else "other_server"),
+                    "join",
+                    extra_content={"a": "b"},
+                )
+            )
             events_to_filter.append(evt)
 
         filtered = self.get_success(
@@ -90,8 +112,19 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
 
     def test_filter_outlier(self) -> None:
         # outlier events must be returned, for the good of the collective federation
-        self._inject_room_member("@resident:remote_hs")
-        self._inject_visibility("@resident:remote_hs", "joined")
+        self.get_success(
+            inject_member_event(
+                self.hs,
+                TEST_ROOM_ID,
+                "@resident:remote_hs",
+                "join",
+            )
+        )
+        self.get_success(
+            inject_visibility_event(
+                self.hs, TEST_ROOM_ID, "@resident:remote_hs", "joined"
+            )
+        )
 
         outlier = self._inject_outlier()
         self.assertEqual(
@@ -110,7 +143,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         )
 
         # it should also work when there are other events in the list
-        evt = self._inject_message("@unerased:local_hs")
+        evt = self.get_success(
+            inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
+        )
 
         filtered = self.get_success(
             filter_events_for_server(
@@ -150,19 +185,34 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         # change in the middle of them.
         events_to_filter = []
 
-        evt = self._inject_message("@unerased:local_hs")
+        evt = self.get_success(
+            inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
+        )
         events_to_filter.append(evt)
 
-        evt = self._inject_message("@erased:local_hs")
+        evt = self.get_success(
+            inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs")
+        )
         events_to_filter.append(evt)
 
-        evt = self._inject_room_member("@joiner:remote_hs")
+        evt = self.get_success(
+            inject_member_event(
+                self.hs,
+                TEST_ROOM_ID,
+                "@joiner:remote_hs",
+                "join",
+            )
+        )
         events_to_filter.append(evt)
 
-        evt = self._inject_message("@unerased:local_hs")
+        evt = self.get_success(
+            inject_message_event(self.hs, TEST_ROOM_ID, "@unerased:local_hs")
+        )
         events_to_filter.append(evt)
 
-        evt = self._inject_message("@erased:local_hs")
+        evt = self.get_success(
+            inject_message_event(self.hs, TEST_ROOM_ID, "@erased:local_hs")
+        )
         events_to_filter.append(evt)
 
         # the erasey user gets erased
@@ -200,76 +250,6 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         for i in (1, 4):
             self.assertNotIn("body", filtered[i].content)
 
-    def _inject_visibility(self, user_id: str, visibility: str) -> EventBase:
-        content = {"history_visibility": visibility}
-        builder = self.event_builder_factory.for_room_version(
-            RoomVersions.V1,
-            {
-                "type": "m.room.history_visibility",
-                "sender": user_id,
-                "state_key": "",
-                "room_id": TEST_ROOM_ID,
-                "content": content,
-            },
-        )
-
-        event, unpersisted_context = self.get_success(
-            self.event_creation_handler.create_new_client_event(builder)
-        )
-        context = self.get_success(unpersisted_context.persist(event))
-        self.get_success(self._persistence.persist_event(event, context))
-        return event
-
-    def _inject_room_member(
-        self,
-        user_id: str,
-        membership: str = "join",
-        extra_content: Optional[JsonDict] = None,
-    ) -> EventBase:
-        content = {"membership": membership}
-        content.update(extra_content or {})
-        builder = self.event_builder_factory.for_room_version(
-            RoomVersions.V1,
-            {
-                "type": "m.room.member",
-                "sender": user_id,
-                "state_key": user_id,
-                "room_id": TEST_ROOM_ID,
-                "content": content,
-            },
-        )
-
-        event, unpersisted_context = self.get_success(
-            self.event_creation_handler.create_new_client_event(builder)
-        )
-        context = self.get_success(unpersisted_context.persist(event))
-
-        self.get_success(self._persistence.persist_event(event, context))
-        return event
-
-    def _inject_message(
-        self, user_id: str, content: Optional[JsonDict] = None
-    ) -> EventBase:
-        if content is None:
-            content = {"body": "testytest", "msgtype": "m.text"}
-        builder = self.event_builder_factory.for_room_version(
-            RoomVersions.V1,
-            {
-                "type": "m.room.message",
-                "sender": user_id,
-                "room_id": TEST_ROOM_ID,
-                "content": content,
-            },
-        )
-
-        event, unpersisted_context = self.get_success(
-            self.event_creation_handler.create_new_client_event(builder)
-        )
-        context = self.get_success(unpersisted_context.persist(event))
-
-        self.get_success(self._persistence.persist_event(event, context))
-        return event
-
     def _inject_outlier(self) -> EventBase:
         builder = self.event_builder_factory.for_room_version(
             RoomVersions.V1,
@@ -292,7 +272,122 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
         return event
 
 
-class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
+class FilterEventsForClientTestCase(HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def test_joined_history_visibility(self) -> None:
+        # User joins and leaves room. Should be able to see the join and leave,
+        # and messages sent between the two, but not before or after.
+
+        self.register_user("resident", "p1")
+        resident_token = self.login("resident", "p1")
+        room_id = self.helper.create_room_as("resident", tok=resident_token)
+
+        self.get_success(
+            inject_visibility_event(self.hs, room_id, "@resident:test", "joined")
+        )
+        before_event = self.get_success(
+            inject_message_event(self.hs, room_id, "@resident:test", body="before")
+        )
+        join_event = self.get_success(
+            inject_member_event(self.hs, room_id, "@joiner:test", "join")
+        )
+        during_event = self.get_success(
+            inject_message_event(self.hs, room_id, "@resident:test", body="during")
+        )
+        leave_event = self.get_success(
+            inject_member_event(self.hs, room_id, "@joiner:test", "leave")
+        )
+        after_event = self.get_success(
+            inject_message_event(self.hs, room_id, "@resident:test", body="after")
+        )
+
+        # We have to reload the events from the db, to ensure that prev_content is
+        # populated.
+        events_to_filter = [
+            self.get_success(
+                self.hs.get_storage_controllers().main.get_event(
+                    e.event_id,
+                    get_prev_content=True,
+                )
+            )
+            for e in [
+                before_event,
+                join_event,
+                during_event,
+                leave_event,
+                after_event,
+            ]
+        ]
+
+        # Now run the events through the filter, and check that we can see the events
+        # we expect, and that the membership prop is as expected.
+        #
+        # We deliberately do the queries for both users upfront; this simulates
+        # concurrent queries on the server, and helps ensure that we aren't
+        # accidentally serving the same event object (with the same unsigned.membership
+        # property) to both users.
+        joiner_filtered_events = self.get_success(
+            filter_events_for_client(
+                self.hs.get_storage_controllers(),
+                "@joiner:test",
+                events_to_filter,
+                msc4115_membership_on_events=True,
+            )
+        )
+        resident_filtered_events = self.get_success(
+            filter_events_for_client(
+                self.hs.get_storage_controllers(),
+                "@resident:test",
+                events_to_filter,
+                msc4115_membership_on_events=True,
+            )
+        )
+
+        # The joiner should be able to seem the join and leave,
+        # and messages sent between the two, but not before or after.
+        self.assertEqual(
+            [e.event_id for e in [join_event, during_event, leave_event]],
+            [e.event_id for e in joiner_filtered_events],
+        )
+        self.assertEqual(
+            ["join", "join", "leave"],
+            [
+                e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP]
+                for e in joiner_filtered_events
+            ],
+        )
+
+        # The resident user should see all the events.
+        self.assertEqual(
+            [
+                e.event_id
+                for e in [
+                    before_event,
+                    join_event,
+                    during_event,
+                    leave_event,
+                    after_event,
+                ]
+            ],
+            [e.event_id for e in resident_filtered_events],
+        )
+        self.assertEqual(
+            ["join", "join", "join", "join", "join"],
+            [
+                e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP]
+                for e in resident_filtered_events
+            ],
+        )
+
+
+class FilterEventsOutOfBandEventsForClientTestCase(
+    unittest.FederatingHomeserverTestCase
+):
     def test_out_of_band_invite_rejection(self) -> None:
         # this is where we have received an invite event over federation, and then
         # rejected it.
@@ -341,15 +436,24 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
         )
 
         # the invited user should be able to see both the invite and the rejection
+        filtered_events = self.get_success(
+            filter_events_for_client(
+                self.hs.get_storage_controllers(),
+                "@user:test",
+                [invite_event, reject_event],
+                msc4115_membership_on_events=True,
+            )
+        )
         self.assertEqual(
-            self.get_success(
-                filter_events_for_client(
-                    self.hs.get_storage_controllers(),
-                    "@user:test",
-                    [invite_event, reject_event],
-                )
-            ),
-            [invite_event, reject_event],
+            [e.event_id for e in filtered_events],
+            [e.event_id for e in [invite_event, reject_event]],
+        )
+        self.assertEqual(
+            ["invite", "leave"],
+            [
+                e.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP]
+                for e in filtered_events
+            ],
         )
 
         # other users should see neither
@@ -359,7 +463,39 @@ class FilterEventsForClientTestCase(unittest.FederatingHomeserverTestCase):
                     self.hs.get_storage_controllers(),
                     "@other:test",
                     [invite_event, reject_event],
+                    msc4115_membership_on_events=True,
                 )
             ),
             [],
         )
+
+
+async def inject_visibility_event(
+    hs: HomeServer,
+    room_id: str,
+    sender: str,
+    visibility: str,
+) -> EventBase:
+    return await inject_event(
+        hs,
+        type="m.room.history_visibility",
+        sender=sender,
+        state_key="",
+        room_id=room_id,
+        content={"history_visibility": visibility},
+    )
+
+
+async def inject_message_event(
+    hs: HomeServer,
+    room_id: str,
+    sender: str,
+    body: Optional[str] = "testytest",
+) -> EventBase:
+    return await inject_event(
+        hs,
+        type="m.room.message",
+        sender=sender,
+        room_id=room_id,
+        content={"body": body, "msgtype": "m.text"},
+    )