diff --git a/synapse/visibility.py b/synapse/visibility.py
index 128413c8aa..dc7b6e4065 100644
--- a/synapse/visibility.py
+++ b/synapse/visibility.py
@@ -27,7 +27,6 @@ from typing import (
Final,
FrozenSet,
List,
- Mapping,
Optional,
Sequence,
Set,
@@ -48,6 +47,7 @@ from synapse.events.utils import clone_event, prune_event
from synapse.logging.opentracing import trace
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
+from synapse.synapse_rust.events import event_visible_to_server
from synapse.types import RetentionPolicy, StateMap, StrCollection, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util import Clock
@@ -135,9 +135,9 @@ async def filter_events_for_client(
retention_policies: Dict[str, RetentionPolicy] = {}
for room_id in room_ids:
- retention_policies[room_id] = (
- await storage.main.get_retention_policy_for_room(room_id)
- )
+ retention_policies[
+ room_id
+ ] = await storage.main.get_retention_policy_for_room(room_id)
def allowed(event: EventBase) -> Optional[EventBase]:
state_after_event = event_id_to_state.get(event.event_id)
@@ -628,17 +628,6 @@ async def filter_events_for_server(
"""Filter a list of events based on whether the target server is allowed to
see them.
- For a fully stated room, the target server is allowed to see an event E if:
- - the state at E has world readable or shared history vis, OR
- - the state at E says that the target server is in the room.
-
- For a partially stated room, the target server is allowed to see E if:
- - E was created by this homeserver, AND:
- - the partial state at E has world readable or shared history vis, OR
- - the partial state at E says that the target server is in the room.
-
- TODO: state before or state after?
-
Args:
storage
target_server_name
@@ -655,35 +644,6 @@ async def filter_events_for_server(
The filtered events.
"""
- def is_sender_erased(event: EventBase, erased_senders: Mapping[str, bool]) -> bool:
- if erased_senders and erased_senders[event.sender]:
- logger.info("Sender of %s has been erased, redacting", event.event_id)
- return True
- return False
-
- def check_event_is_visible(
- visibility: str, memberships: StateMap[EventBase]
- ) -> bool:
- if visibility not in (HistoryVisibility.INVITED, HistoryVisibility.JOINED):
- return True
-
- # We now loop through all membership events looking for
- # membership states for the requesting server to determine
- # if the server is either in the room or has been invited
- # into the room.
- for ev in memberships.values():
- assert get_domain_from_id(ev.state_key) == target_server_name
-
- memtype = ev.membership
- if memtype == Membership.JOIN:
- return True
- elif memtype == Membership.INVITE:
- if visibility == HistoryVisibility.INVITED:
- return True
-
- # server has no users in the room: redact
- return False
-
if filter_out_erased_senders:
erased_senders = await storage.main.are_users_erased(e.sender for e in events)
else:
@@ -726,20 +686,16 @@ async def filter_events_for_server(
target_server_name,
)
- def include_event_in_output(e: EventBase) -> bool:
- erased = is_sender_erased(e, erased_senders)
- visible = check_event_is_visible(
- event_to_history_vis[e.event_id], event_to_memberships.get(e.event_id, {})
- )
-
- if e.event_id in partial_state_invisible_event_ids:
- visible = False
-
- return visible and not erased
-
to_return = []
for e in events:
- if include_event_in_output(e):
+ if event_visible_to_server(
+ sender=e.sender,
+ target_server_name=target_server_name,
+ history_visibility=event_to_history_vis[e.event_id],
+ erased_senders=erased_senders,
+ partial_state_invisible=e.event_id in partial_state_invisible_event_ids,
+ memberships=list(event_to_memberships.get(e.event_id, {}).values()),
+ ):
to_return.append(e)
elif redact:
to_return.append(prune_event(e))
@@ -796,7 +752,7 @@ async def _event_to_history_vis(
async def _event_to_memberships(
storage: StorageControllers, events: Collection[EventBase], server_name: str
-) -> Dict[str, StateMap[EventBase]]:
+) -> Dict[str, StateMap[Tuple[str, str]]]:
"""Get the remote membership list at each of the given events
Returns a map from event id to state map, which will contain only membership events
@@ -849,7 +805,7 @@ async def _event_to_memberships(
return {
e_id: {
- key: event_map[inner_e_id]
+ key: (event_map[inner_e_id].state_key, event_map[inner_e_id].membership)
for key, inner_e_id in key_to_eid.items()
if inner_e_id in event_map
}
|