summary refs log tree commit diff
path: root/synapse/visibility.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/visibility.py72
1 files changed, 14 insertions, 58 deletions
diff --git a/synapse/visibility.py b/synapse/visibility.py

index 128413c8aa..dc7b6e4065 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py
@@ -27,7 +27,6 @@ from typing import ( Final, FrozenSet, List, - Mapping, Optional, Sequence, Set, @@ -48,6 +47,7 @@ from synapse.events.utils import clone_event, prune_event from synapse.logging.opentracing import trace from synapse.storage.controllers import StorageControllers from synapse.storage.databases.main import DataStore +from synapse.synapse_rust.events import event_visible_to_server from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id from synapse.types.state import StateFilter from synapse.util import Clock @@ -135,9 +135,9 @@ async def filter_events_for_client( retention_policies: Dict[str, RetentionPolicy] = {} for room_id in room_ids: - retention_policies[room_id] = ( - await storage.main.get_retention_policy_for_room(room_id) - ) + retention_policies[ + room_id + ] = await storage.main.get_retention_policy_for_room(room_id) def allowed(event: EventBase) -> Optional[EventBase]: state_after_event = event_id_to_state.get(event.event_id) @@ -628,17 +628,6 @@ async def filter_events_for_server( """Filter a list of events based on whether the target server is allowed to see them. - For a fully stated room, the target server is allowed to see an event E if: - - the state at E has world readable or shared history vis, OR - - the state at E says that the target server is in the room. - - For a partially stated room, the target server is allowed to see E if: - - E was created by this homeserver, AND: - - the partial state at E has world readable or shared history vis, OR - - the partial state at E says that the target server is in the room. - - TODO: state before or state after? - Args: storage target_server_name @@ -655,35 +644,6 @@ async def filter_events_for_server( The filtered events. """ - def is_sender_erased(event: EventBase, erased_senders: Mapping[str, bool]) -> bool: - if erased_senders and erased_senders[event.sender]: - logger.info("Sender of %s has been erased, redacting", event.event_id) - return True - return False - - def check_event_is_visible( - visibility: str, memberships: StateMap[EventBase] - ) -> bool: - if visibility not in (HistoryVisibility.INVITED, HistoryVisibility.JOINED): - return True - - # We now loop through all membership events looking for - # membership states for the requesting server to determine - # if the server is either in the room or has been invited - # into the room. - for ev in memberships.values(): - assert get_domain_from_id(ev.state_key) == target_server_name - - memtype = ev.membership - if memtype == Membership.JOIN: - return True - elif memtype == Membership.INVITE: - if visibility == HistoryVisibility.INVITED: - return True - - # server has no users in the room: redact - return False - if filter_out_erased_senders: erased_senders = await storage.main.are_users_erased(e.sender for e in events) else: @@ -726,20 +686,16 @@ async def filter_events_for_server( target_server_name, ) - def include_event_in_output(e: EventBase) -> bool: - erased = is_sender_erased(e, erased_senders) - visible = check_event_is_visible( - event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {}) - ) - - if e.event_id in partial_state_invisible_event_ids: - visible = False - - return visible and not erased - to_return = [] for e in events: - if include_event_in_output(e): + if event_visible_to_server( + sender=e.sender, + target_server_name=target_server_name, + history_visibility=event_to_history_vis[e.event_id], + erased_senders=erased_senders, + partial_state_invisible=e.event_id in partial_state_invisible_event_ids, + memberships=list(event_to_memberships.get(e.event_id, {}).values()), + ): to_return.append(e) elif redact: to_return.append(prune_event(e)) @@ -796,7 +752,7 @@ async def _event_to_history_vis( async def _event_to_memberships( storage: StorageControllers, events: Collection[EventBase], server_name: str -) -> Dict[str, StateMap[EventBase]]: +) -> Dict[str, StateMap[Tuple[str, str]]]: """Get the remote membership list at each of the given events Returns a map from event id to state map, which will contain only membership events @@ -849,7 +805,7 @@ async def _event_to_memberships( return { e_id: { - key: event_map[inner_e_id] + key: (event_map[inner_e_id].state_key, event_map[inner_e_id].membership) for key, inner_e_id in key_to_eid.items() if inner_e_id in event_map }