diff options
Diffstat (limited to 'synapse/visibility.py')
-rw-r--r-- | synapse/visibility.py | 73 |
1 files changed, 58 insertions, 15 deletions
diff --git a/synapse/visibility.py b/synapse/visibility.py index d1d478129f..09a947ef15 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -36,10 +36,15 @@ from typing import ( import attr -from synapse.api.constants import EventTypes, HistoryVisibility, Membership +from synapse.api.constants import ( + EventTypes, + EventUnsignedContentFields, + HistoryVisibility, + Membership, +) from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.events.utils import prune_event +from synapse.events.utils import clone_event, prune_event from synapse.logging.opentracing import trace from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore @@ -77,6 +82,7 @@ async def filter_events_for_client( is_peeking: bool = False, always_include_ids: FrozenSet[str] = frozenset(), filter_send_to_client: bool = True, + msc4115_membership_on_events: bool = False, ) -> List[EventBase]: """ Check which events a user is allowed to see. If the user can see the event but its @@ -95,9 +101,12 @@ async def filter_events_for_client( filter_send_to_client: Whether we're checking an event that's going to be sent to a client. This might not always be the case since this function can also be called to check whether a user can see the state at a given point. + msc4115_membership_on_events: Whether to include the requesting user's + membership in the "unsigned" data, per MSC4115. Returns: - The filtered events. + The filtered events. If `msc4115_membership_on_events` is true, the `unsigned` + data is annotated with the membership state of `user_id` at each event. """ # Filter out events that have been soft failed so that we don't relay them # to clients. @@ -134,7 +143,8 @@ async def filter_events_for_client( ) def allowed(event: EventBase) -> Optional[EventBase]: - return _check_client_allowed_to_see_event( + state_after_event = event_id_to_state.get(event.event_id) + filtered = _check_client_allowed_to_see_event( user_id=user_id, event=event, clock=storage.main.clock, @@ -142,13 +152,45 @@ async def filter_events_for_client( sender_ignored=event.sender in ignore_list, always_include_ids=always_include_ids, retention_policy=retention_policies[room_id], - state=event_id_to_state.get(event.event_id), + state=state_after_event, is_peeking=is_peeking, sender_erased=erased_senders.get(event.sender, False), ) + if filtered is None: + return None + + if not msc4115_membership_on_events: + return filtered + + # Annotate the event with the user's membership after the event. + # + # Normally we just look in `state_after_event`, but if the event is an outlier + # we won't have such a state. The only outliers that are returned here are the + # user's own membership event, so we can just inspect that. + + user_membership_event: Optional[EventBase] + if event.type == EventTypes.Member and event.state_key == user_id: + user_membership_event = event + elif state_after_event is not None: + user_membership_event = state_after_event.get((EventTypes.Member, user_id)) + else: + # unreachable! + raise Exception("Missing state for event that is not user's own membership") + + user_membership = ( + user_membership_event.membership + if user_membership_event + else Membership.LEAVE + ) - # Check each event: gives an iterable of None or (a potentially modified) - # EventBase. + # Copy the event before updating the unsigned data: this shouldn't be persisted + # to the cache! + cloned = clone_event(filtered) + cloned.unsigned[EventUnsignedContentFields.MSC4115_MEMBERSHIP] = user_membership + + return cloned + + # Check each event: gives an iterable of None or (a modified) EventBase. filtered_events = map(allowed, events) # Turn it into a list and remove None entries before returning. @@ -396,7 +438,13 @@ def _check_client_allowed_to_see_event( @attr.s(frozen=True, slots=True, auto_attribs=True) class _CheckMembershipReturn: - "Return value of _check_membership" + """Return value of `_check_membership`. + + Attributes: + allowed: Whether the user should be allowed to see the event. + joined: Whether the user was joined to the room at the event. + """ + allowed: bool joined: bool @@ -408,12 +456,7 @@ def _check_membership( state: StateMap[EventBase], is_peeking: bool, ) -> _CheckMembershipReturn: - """Check whether the user can see the event due to their membership - - Returns: - True if they can, False if they can't, plus the membership of the user - at the event. - """ + """Check whether the user can see the event due to their membership""" # If the event is the user's own membership event, use the 'most joined' # membership membership = None @@ -435,7 +478,7 @@ def _check_membership( if membership == "leave" and ( prev_membership == "join" or prev_membership == "invite" ): - return _CheckMembershipReturn(True, membership == Membership.JOIN) + return _CheckMembershipReturn(True, False) new_priority = MEMBERSHIP_PRIORITY.index(membership) old_priority = MEMBERSHIP_PRIORITY.index(prev_membership) |