summary refs log tree commit diff
path: root/synapse/visibility.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/visibility.py')
-rw-r--r--synapse/visibility.py73
1 files changed, 58 insertions, 15 deletions
diff --git a/synapse/visibility.py b/synapse/visibility.py
index d1d478129f..09a947ef15 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -36,10 +36,15 @@ from typing import (
 
 import attr
 
-from synapse.api.constants import EventTypes, HistoryVisibility, Membership
+from synapse.api.constants import (
+    EventTypes,
+    EventUnsignedContentFields,
+    HistoryVisibility,
+    Membership,
+)
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
-from synapse.events.utils import prune_event
+from synapse.events.utils import clone_event, prune_event
 from synapse.logging.opentracing import trace
 from synapse.storage.controllers import StorageControllers
 from synapse.storage.databases.main import DataStore
@@ -77,6 +82,7 @@ async def filter_events_for_client(
     is_peeking: bool = False,
     always_include_ids: FrozenSet[str] = frozenset(),
     filter_send_to_client: bool = True,
+    msc4115_membership_on_events: bool = False,
 ) -> List[EventBase]:
     """
     Check which events a user is allowed to see. If the user can see the event but its
@@ -95,9 +101,12 @@ async def filter_events_for_client(
         filter_send_to_client: Whether we're checking an event that's going to be
             sent to a client. This might not always be the case since this function can
             also be called to check whether a user can see the state at a given point.
+        msc4115_membership_on_events: Whether to include the requesting user's
+            membership in the "unsigned" data, per MSC4115.
 
     Returns:
-        The filtered events.
+        The filtered events. If `msc4115_membership_on_events` is true, the `unsigned`
+        data is annotated with the membership state of `user_id` at each event.
     """
     # Filter out events that have been soft failed so that we don't relay them
     # to clients.
@@ -134,7 +143,8 @@ async def filter_events_for_client(
             )
 
     def allowed(event: EventBase) -> Optional[EventBase]:
-        return _check_client_allowed_to_see_event(
+        state_after_event = event_id_to_state.get(event.event_id)
+        filtered = _check_client_allowed_to_see_event(
             user_id=user_id,
             event=event,
             clock=storage.main.clock,
@@ -142,13 +152,45 @@ async def filter_events_for_client(
             sender_ignored=event.sender in ignore_list,
             always_include_ids=always_include_ids,
             retention_policy=retention_policies[room_id],
-            state=event_id_to_state.get(event.event_id),
+            state=state_after_event,
             is_peeking=is_peeking,
             sender_erased=erased_senders.get(event.sender, False),
         )
+        if filtered is None:
+            return None
+
+        if not msc4115_membership_on_events:
+            return filtered
+
+        # Annotate the event with the user's membership after the event.
+        #
+        # Normally we just look in `state_after_event`, but if the event is an outlier
+        # we won't have such a state. The only outliers that are returned here are the
+        # user's own membership event, so we can just inspect that.
+
+        user_membership_event: Optional[EventBase]
+        if event.type == EventTypes.Member and event.state_key == user_id:
+            user_membership_event = event
+        elif state_after_event is not None:
+            user_membership_event = state_after_event.get((EventTypes.Member, user_id))
+        else:
+            # unreachable!
+            raise Exception("Missing state for event that is not user's own membership")
+
+        user_membership = (
+            user_membership_event.membership
+            if user_membership_event
+            else Membership.LEAVE
+        )
 
-    # Check each event: gives an iterable of None or (a potentially modified)
-    # EventBase.
+        # Copy the event before updating the unsigned data: this shouldn't be persisted
+        # to the cache!
+        cloned = clone_event(filtered)
+        cloned.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP] = user_membership
+
+        return cloned
+
+    # Check each event: gives an iterable of None or (a modified) EventBase.
     filtered_events = map(allowed, events)
 
     # Turn it into a list and remove None entries before returning.
@@ -396,7 +438,13 @@ def _check_client_allowed_to_see_event(
 
 @attr.s(frozen=True, slots=True, auto_attribs=True)
 class _CheckMembershipReturn:
-    "Return value of _check_membership"
+    """Return value of `_check_membership`.
+
+    Attributes:
+        allowed: Whether the user should be allowed to see the event.
+        joined: Whether the user was joined to the room at the event.
+    """
+
     allowed: bool
     joined: bool
 
@@ -408,12 +456,7 @@ def _check_membership(
     state: StateMap[EventBase],
     is_peeking: bool,
 ) -> _CheckMembershipReturn:
-    """Check whether the user can see the event due to their membership
-
-    Returns:
-        True if they can, False if they can't, plus the membership of the user
-        at the event.
-    """
+    """Check whether the user can see the event due to their membership"""
     # If the event is the user's own membership event, use the 'most joined'
     # membership
     membership = None
@@ -435,7 +478,7 @@ def _check_membership(
         if membership == "leave" and (
             prev_membership == "join" or prev_membership == "invite"
         ):
-            return _CheckMembershipReturn(True, membership == Membership.JOIN)
+            return _CheckMembershipReturn(True, False)
 
         new_priority = MEMBERSHIP_PRIORITY.index(membership)
         old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)